From 6969d6b8ba4e65ec22a37f299011aeda86f27c98 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 20 Jan 2025 14:06:18 +0100 Subject: [PATCH 01/39] move ua parser and type to ua module of rama-ua as to make room for also adding the profile types --- rama-ua/src/lib.rs | 44 ++++++----------------------- rama-ua/src/{ => ua}/info.rs | 0 rama-ua/src/ua/mod.rs | 33 ++++++++++++++++++++++ rama-ua/src/{ => ua}/parse.rs | 0 rama-ua/src/{ => ua}/parse_tests.rs | 0 5 files changed, 42 insertions(+), 35 deletions(-) rename rama-ua/src/{ => ua}/info.rs (100%) create mode 100644 rama-ua/src/ua/mod.rs rename rama-ua/src/{ => ua}/parse.rs (100%) rename rama-ua/src/{ => ua}/parse_tests.rs (100%) diff --git a/rama-ua/src/lib.rs b/rama-ua/src/lib.rs index 94d4fb20..26feccf9 100644 --- a/rama-ua/src/lib.rs +++ b/rama-ua/src/lib.rs @@ -1,9 +1,14 @@ -//! User Agent (UA) parser and types. +//! User Agent (UA) parser and profiles. //! -//! This module provides a parser ([`UserAgent::new`]) for User Agents +//! This crate provides a parser ([`UserAgent::new`]) for User Agents //! as well as a classifier (`UserAgentClassifierLayer` in `rama_http`) that can be used to //! classify incoming requests based on their User Agent (header). //! +//! These can be used to know what UA is connecting to a server, +//! but it can also be used to emulate the UA from a client +//! via the profiles that are found in this crate as well, +//! be it builtin modules or custom ones. +//! //! Learn more about User Agents (UA) and why Rama supports it //! at . //! @@ -55,36 +60,5 @@ #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] -use serde::{Deserialize, Serialize}; - -mod info; -pub use info::{ - DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind, -}; - -mod parse; -use parse::parse_http_user_agent_header; - -/// Information that can be used to overwrite the [`UserAgent`] of an http request. 
-/// -/// Used by the `UserAgentClassifier` (see `rama-http`) to overwrite the specified -/// information duing the classification of the [`UserAgent`]. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct UserAgentOverwrites { - /// Overwrite the [`UserAgent`] of the http `Request` with a custom value. - /// - /// This value will be used instead of - /// the 'User-Agent' http (header) value. - /// - /// This is useful in case you cannot set the User-Agent header in your request. - pub ua: Option, - /// Overwrite the [`HttpAgent`] of the http `Request` with a custom value. - pub http: Option, - /// Overwrite the [`TlsAgent`] of the http `Request` with a custom value. - pub tls: Option, - /// Preserve the original [`UserAgent`] header of the http `Request`. - pub preserve_ua: Option, -} - -#[cfg(test)] -mod parse_tests; +mod ua; +pub use ua::*; diff --git a/rama-ua/src/info.rs b/rama-ua/src/ua/info.rs similarity index 100% rename from rama-ua/src/info.rs rename to rama-ua/src/ua/info.rs diff --git a/rama-ua/src/ua/mod.rs b/rama-ua/src/ua/mod.rs new file mode 100644 index 00000000..d397001c --- /dev/null +++ b/rama-ua/src/ua/mod.rs @@ -0,0 +1,33 @@ +use serde::{Deserialize, Serialize}; + +mod info; +pub use info::{ + DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind, +}; + +mod parse; +use parse::parse_http_user_agent_header; + +/// Information that can be used to overwrite the [`UserAgent`] of an http request. +/// +/// Used by the `UserAgentClassifier` (see `rama-http`) to overwrite the specified +/// information during the classification of the [`UserAgent`]. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct UserAgentOverwrites { + /// Overwrite the [`UserAgent`] of the http `Request` with a custom value. + /// + /// This value will be used instead of + /// the 'User-Agent' http (header) value. + /// + /// This is useful in case you cannot set the User-Agent header in your request. 
+ pub ua: Option, + /// Overwrite the [`HttpAgent`] of the http `Request` with a custom value. + pub http: Option, + /// Overwrite the [`TlsAgent`] of the http `Request` with a custom value. + pub tls: Option, + /// Preserve the original [`UserAgent`] header of the http `Request`. + pub preserve_ua: Option, +} + +#[cfg(test)] +mod parse_tests; diff --git a/rama-ua/src/parse.rs b/rama-ua/src/ua/parse.rs similarity index 100% rename from rama-ua/src/parse.rs rename to rama-ua/src/ua/parse.rs diff --git a/rama-ua/src/parse_tests.rs b/rama-ua/src/ua/parse_tests.rs similarity index 100% rename from rama-ua/src/parse_tests.rs rename to rama-ua/src/ua/parse_tests.rs From bff7d85ada6fa16dce4b64e7703e703c5c62b7fa Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Tue, 21 Jan 2025 15:15:31 +0100 Subject: [PATCH 02/39] start defining UA profile types and prepare for encoding --- Cargo.lock | 5 + Cargo.toml | 1 + rama-net/Cargo.toml | 1 + rama-net/src/tls/client/hello/mod.rs | 6 +- rama-net/src/tls/enums/mod.rs | 111 +++++++++++++++++++++ rama-ua/Cargo.toml | 9 ++ rama-ua/src/lib.rs | 3 + rama-ua/src/profile/http.rs | 144 +++++++++++++++++++++++++++ rama-ua/src/profile/mod.rs | 10 ++ rama-ua/src/profile/tls.rs | 10 ++ rama-ua/src/profile/ua.rs | 29 ++++++ 11 files changed, 327 insertions(+), 2 deletions(-) create mode 100644 rama-ua/src/profile/http.rs create mode 100644 rama-ua/src/profile/mod.rs create mode 100644 rama-ua/src/profile/tls.rs create mode 100644 rama-ua/src/profile/ua.rs diff --git a/Cargo.lock b/Cargo.lock index 493f1252..c5219f1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2357,6 +2357,7 @@ dependencies = [ "rama-utils", "rustls", "serde", + "serde_json", "sha2", "tokio", "tokio-test", @@ -2431,11 +2432,15 @@ dependencies = [ name = "rama-ua" version = "0.2.0-alpha.7" dependencies = [ + "bytes", "rama-core", + "rama-http-types", + "rama-net", "rama-utils", "serde", "serde_json", "tokio", + "venndb", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml 
index 87e99c50..97424a9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -229,6 +229,7 @@ http-full = ["http", "tcp", "dep:rama-http-backend", "dep:rama-http-core"] proxy = ["dep:rama-proxy"] haproxy = ["dep:rama-haproxy"] ua = ["dep:rama-ua"] +ua-memory-db = ["ua", "rama-ua/memory-db"] proxy-memory-db = ["proxy", "rama-proxy/memory-db", "rama-net/venndb"] proxy-live-update = ["proxy", "rama-proxy/live-update"] proxy-csv = ["proxy", "rama-proxy/csv"] diff --git a/rama-net/Cargo.toml b/rama-net/Cargo.toml index 1c8c1006..5ee767fb 100644 --- a/rama-net/Cargo.toml +++ b/rama-net/Cargo.toml @@ -50,6 +50,7 @@ venndb = { workspace = true, optional = true } itertools = { workspace = true } nom = { workspace = true } quickcheck = { workspace = true } +serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tokio-test = { workspace = true } diff --git a/rama-net/src/tls/client/hello/mod.rs b/rama-net/src/tls/client/hello/mod.rs index 6d8d67fd..25e31d8b 100644 --- a/rama-net/src/tls/client/hello/mod.rs +++ b/rama-net/src/tls/client/hello/mod.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; + use crate::address::Host; use crate::tls::{ enums::CompressionAlgorithm, ApplicationProtocol, CipherSuite, ECPointFormat, ExtensionId, @@ -10,7 +12,7 @@ mod rustls; #[cfg(feature = "boring")] mod boring; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] /// When a client first connects to a server, it is required to send /// the ClientHello as its first message. /// @@ -125,7 +127,7 @@ impl ClientHello { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] /// Extensions that can be set in a [`ClientHello`] message by a TLS client. 
/// /// While its name may infer that an extension is by definition optional, diff --git a/rama-net/src/tls/enums/mod.rs b/rama-net/src/tls/enums/mod.rs index 38b51ac9..44a141b2 100644 --- a/rama-net/src/tls/enums/mod.rs +++ b/rama-net/src/tls/enums/mod.rs @@ -66,6 +66,27 @@ macro_rules! enum_builder { ::std::fmt::UpperHex::fmt(&u8::from(*self), f) } } + + impl ::serde::Serialize for $enum_name { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + u8::from(*self).serialize(serializer) + } + } + + impl<'de> ::serde::Deserialize<'de> for $enum_name { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + let n = u8::deserialize(deserializer)?; + Ok(n.into()) + } + } }; ( $(#[$comment:meta])* @@ -133,6 +154,27 @@ macro_rules! enum_builder { ::std::fmt::UpperHex::fmt(&u16::from(*self), f) } } + + impl ::serde::Serialize for $enum_name { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + u16::from(*self).serialize(serializer) + } + } + + impl<'de> ::serde::Deserialize<'de> for $enum_name { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + let n = u16::deserialize(deserializer)?; + Ok(n.into()) + } + } }; ( $(#[$comment:meta])* @@ -256,6 +298,34 @@ macro_rules! 
enum_builder { } } } + + impl ::serde::Serialize for $enum_name { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + match self { + $( $enum_name::$enum_var => { + $enum_val.serialize(serializer) + }),* + ,$enum_name::Unknown(x) => { + x.serialize(serializer) + } + } + } + } + + impl<'de> ::serde::Deserialize<'de> for $enum_name { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + let b = <::std::borrow::Cow<'de, [u8]>>::deserialize(deserializer)?; + Ok(b.as_ref().into()) + } + } }; } @@ -1021,4 +1091,45 @@ mod tests { assert_eq!(12, r.position()); assert_eq!(&INPUT.as_bytes()[3..12], b"\x08http/1.1"); } + + #[test] + fn test_enum_u8_serialize_deserialize() { + let p: ECPointFormat = serde_json::from_str( + &serde_json::to_string(&ECPointFormat::ANSIX962CompressedChar2).unwrap(), + ) + .unwrap(); + assert_eq!(ECPointFormat::ANSIX962CompressedChar2, p); + + let p: ECPointFormat = + serde_json::from_str(&serde_json::to_string(&ECPointFormat::from(42u8)).unwrap()) + .unwrap(); + assert_eq!(ECPointFormat::from(42u8), p); + } + + #[test] + fn test_enum_u16_serialize_deserialize() { + let p: SupportedGroup = + serde_json::from_str(&serde_json::to_string(&SupportedGroup::BRAINPOOLP384R1).unwrap()) + .unwrap(); + assert_eq!(SupportedGroup::BRAINPOOLP384R1, p); + + let p: SupportedGroup = + serde_json::from_str(&serde_json::to_string(&SupportedGroup::from(0xffffu16)).unwrap()) + .unwrap(); + assert_eq!(SupportedGroup::from(0xffffu16), p); + } + + #[test] + fn test_enum_bytes_serialize_deserialize() { + let p: ApplicationProtocol = + serde_json::from_str(&serde_json::to_string(&ApplicationProtocol::HTTP_3).unwrap()) + .unwrap(); + assert_eq!(ApplicationProtocol::HTTP_3, p); + + let p: ApplicationProtocol = serde_json::from_str( + &serde_json::to_string(&ApplicationProtocol::from(b"foobar")).unwrap(), + ) + .unwrap(); + assert_eq!(ApplicationProtocol::from(b"foobar"), p); + } } 
diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index 664f271a..f802e950 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -13,10 +13,19 @@ rust-version = { workspace = true } [lints] workspace = true +[features] +default = [] +memory-db = ["dep:venndb"] +tls = ["dep:rama-net", "rama-net/tls"] + [dependencies] +bytes = { workspace = true } rama-core = { version = "0.2.0-alpha.7", path = "../rama-core" } +rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } +rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", optional = true } rama-utils = { version = "0.2.0-alpha.7", path = "../rama-utils" } serde = { workspace = true, features = ["derive"] } +venndb = { workspace = true, optional = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/rama-ua/src/lib.rs b/rama-ua/src/lib.rs index 26feccf9..7c155539 100644 --- a/rama-ua/src/lib.rs +++ b/rama-ua/src/lib.rs @@ -62,3 +62,6 @@ mod ua; pub use ua::*; + +mod profile; +pub use profile::*; diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs new file mode 100644 index 00000000..9c92a8b8 --- /dev/null +++ b/rama-ua/src/profile/http.rs @@ -0,0 +1,144 @@ +use rama_core::error::OpaqueError; +use rama_http_types::proto::h2::PseudoHeader; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[cfg_attr(feature = "memory-db", derive(venndb::VennDB))] +pub struct HttpProfile { + #[cfg_attr(feature = "memory-db", venndb(key))] + pub ja4h: String, + pub http_headers: Vec<(String, String)>, + pub http_pseudo_headers: Vec, + #[cfg_attr(feature = "memory-db", venndb(filter))] + pub fetch_mode: Option, + #[cfg_attr(feature = "memory-db", venndb(filter))] + pub resource_type: Option, + #[cfg_attr(feature = "memory-db", venndb(filter))] + pub initiator: Option, + #[cfg_attr(feature = "memory-db", venndb(filter))] + pub http_version: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, 
Ord, Hash, Deserialize, Serialize)] +pub enum FetchMode { + Cors, + Navigate, + NoCors, + SameOrigin, + Websocket, +} + +impl std::fmt::Display for FetchMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Cors => write!(f, "cors"), + Self::Navigate => write!(f, "navigate"), + Self::NoCors => write!(f, "no-cors"), + Self::SameOrigin => write!(f, "same-origin"), + Self::Websocket => write!(f, "websocket"), + } + } +} + +impl FromStr for FetchMode { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "cors" => Ok(Self::Cors), + "navigate" => Ok(Self::Navigate), + "no-cors" => Ok(Self::NoCors), + "same-origin" => Ok(Self::SameOrigin), + "websocket" => Ok(Self::Websocket), + _ => Err(s.to_owned()), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub enum ResourceType { + Document, + Xhr, + Form, +} + +impl std::fmt::Display for ResourceType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Document => write!(f, "document"), + Self::Xhr => write!(f, "xhr"), + Self::Form => write!(f, "form"), + } + } +} + +impl FromStr for ResourceType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "document" => Ok(Self::Document), + "xhr" => Ok(Self::Xhr), + "form" => Ok(Self::Form), + _ => Err(s.to_owned()), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub enum Initiator { + Navigator, + Fetch, + XMLHttpRequest, + Form, +} + +impl std::fmt::Display for Initiator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Navigator => write!(f, "navigator"), + Self::Fetch => write!(f, "fetch"), + Self::XMLHttpRequest => write!(f, "xmlhttprequest"), + Self::Form => write!(f, "form"), + } + } +} + +impl FromStr for Initiator { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + 
"navigator" => Ok(Self::Navigator), + "fetch" => Ok(Self::Fetch), + "xmlhttprequest" => Ok(Self::XMLHttpRequest), + "form" => Ok(Self::Form), + _ => Err(s.to_owned()), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub enum HttpVersion { + H1, + H2, + H3, +} + +impl FromStr for HttpVersion { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + Ok(match s.trim().to_lowercase().as_str() { + "h1" | "http1" | "http/1" | "http/1.0" | "http/1.1" => Self::H1, + "h2" | "http2" | "http/2" | "http/2.0" => Self::H2, + "h3" | "http3" | "http/3" | "http/3.0" => Self::H3, + version => { + return Err(OpaqueError::from_display(format!( + "unsupported http version: {version}" + ))) + } + }) + } +} diff --git a/rama-ua/src/profile/mod.rs b/rama-ua/src/profile/mod.rs new file mode 100644 index 00000000..0be22223 --- /dev/null +++ b/rama-ua/src/profile/mod.rs @@ -0,0 +1,10 @@ +mod ua; +pub use ua::*; + +mod http; +pub use http::*; + +#[cfg(feature = "tls")] +mod tls; +#[cfg(feature = "tls")] +pub use tls::*; diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs new file mode 100644 index 00000000..d99cd35e --- /dev/null +++ b/rama-ua/src/profile/tls.rs @@ -0,0 +1,10 @@ +use rama_net::tls::client::ClientHello; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[cfg_attr(feature = "memory-db", derive(venndb::VennDB))] +pub struct TlsProfile { + #[cfg_attr(feature = "memory-db", venndb(key))] + pub ja4: String, + pub client_hello: ClientHello, +} diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs new file mode 100644 index 00000000..a6e1358f --- /dev/null +++ b/rama-ua/src/profile/ua.rs @@ -0,0 +1,29 @@ +use crate::{DeviceKind, UserAgentKind}; + +#[derive(Debug)] +#[cfg_attr(feature = "memory-db", derive(venndb::VennDB))] +pub struct UserAgentProfile { + #[cfg_attr(feature = "memory-db", venndb(key))] + pub header: String, + #[cfg_attr(feature = 
"memory-db", venndb(filter))] + pub kind: Option, + #[cfg_attr(feature = "memory-db", venndb(filter))] + pub platform_kind: Option, + #[cfg_attr(feature = "memory-db", venndb(filter))] + pub device_kind: Option, + pub version: Option, + + #[cfg(feature = "memory-db")] + pub http_profiles: crate::HttpProfileDB, + #[cfg(not(feature = "memory-db"))] + pub http_profiles: Vec, + + #[cfg(all(feature = "tls", feature = "memory-db"))] + pub tls_profiles: crate::TlsProfileDB, + #[cfg(all(feature = "tls", not(feature = "memory-db")))] + pub tls_profiles: Vec, +} + +// TODO support serialize / deseralize fo this struct and its property types +// TODO implement querying profiles +// TODO add query tests From 1bf044ab43624813259918353d50c4352719d89e Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sun, 26 Jan 2025 16:23:42 +0100 Subject: [PATCH 03/39] wip comments --- rama-ua/src/profile/ua.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs index a6e1358f..ed38e675 100644 --- a/rama-ua/src/profile/ua.rs +++ b/rama-ua/src/profile/ua.rs @@ -27,3 +27,11 @@ pub struct UserAgentProfile { // TODO support serialize / deseralize fo this struct and its property types // TODO implement querying profiles // TODO add query tests +// +// TODO: do we really need VennDB here, we might be better off flattening it different... +// also we need to take into account the market spread +// +// TODO should we strip out heavy duplicate data? e.g. there is probably a lot of duplication +// in the TlsProfileData (inner) and HttpProfileData (inner) +// +// TODO: do we need to really take into account initiator, fetch and Resource type? 
From 5accf8cf220b1d24ef967a0b9c8f3aab057d724f Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Thu, 6 Feb 2025 15:17:33 +0100 Subject: [PATCH 04/39] implement ua db + improve types next steps: - add unit tests - refactor rama-fp to make use of this - add storage trait in rama-fp - store data in postgres for rama-fp (@fly.io) - expose data in website - test & iterate - ship embedding profiles automatically via a script into rama-ua we'll do in a follow-up PR --- Cargo.lock | 10 +- Cargo.toml | 2 +- rama-ua/Cargo.toml | 5 +- rama-ua/src/profile/db.rs | 238 +++++++++++++++++++++++++++++++++++ rama-ua/src/profile/http.rs | 129 ++++++++----------- rama-ua/src/profile/mod.rs | 6 +- rama-ua/src/profile/tls.rs | 22 +++- rama-ua/src/profile/ua.rs | 37 ------ rama-ua/src/ua/info.rs | 239 ++++++++++++++++++++++++------------ 9 files changed, 483 insertions(+), 205 deletions(-) create mode 100644 rama-ua/src/profile/db.rs delete mode 100644 rama-ua/src/profile/ua.rs diff --git a/Cargo.lock b/Cargo.lock index 0975724a..90507d99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1212,6 +1212,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "highway" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9040319a6910b901d5d49cbada4a99db52836a1b63228a05f7e2b7f8feef89b1" + [[package]] name = "home" version = "0.5.11" @@ -2477,14 +2483,16 @@ name = "rama-ua" version = "0.2.0-alpha.7" dependencies = [ "bytes", + "highway", "rama-core", "rama-http-types", "rama-net", "rama-utils", + "rand 0.9.0", "serde", "serde_json", "tokio", - "venndb", + "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8b83b009..5e5acf8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,6 +149,7 @@ want = "0.3" futures-util = "0.3" futures-channel = "0.3" sha2 = "0.10.8" +highway = "1.3.0" [workspace.lints.rust] unreachable_pub = "deny" @@ -230,7 +231,6 @@ http-full = ["http", "tcp", "dep:rama-http-backend", "dep:rama-http-core"] proxy = 
["dep:rama-proxy"] haproxy = ["dep:rama-haproxy"] ua = ["dep:rama-ua"] -ua-memory-db = ["ua", "rama-ua/memory-db"] proxy-memory-db = ["proxy", "rama-proxy/memory-db", "rama-net/venndb"] proxy-live-update = ["proxy", "rama-proxy/live-update"] proxy-csv = ["proxy", "rama-proxy/csv"] diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index f802e950..febfb6b3 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -15,7 +15,6 @@ workspace = true [features] default = [] -memory-db = ["dep:venndb"] tls = ["dep:rama-net", "rama-net/tls"] [dependencies] @@ -25,7 +24,9 @@ rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", optional = true } rama-utils = { version = "0.2.0-alpha.7", path = "../rama-utils" } serde = { workspace = true, features = ["derive"] } -venndb = { workspace = true, optional = true } +rand = { workspace = true } +tracing = { workspace = true } +highway = { workspace = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs new file mode 100644 index 00000000..a004a856 --- /dev/null +++ b/rama-ua/src/profile/db.rs @@ -0,0 +1,238 @@ +use std::collections::HashMap; +use rand::{distr::{weighted::WeightedIndex, Distribution as _}, seq::{IndexedRandom as _, IteratorRandom as _}}; +use serde::{Deserialize, Serialize}; + +use crate::{PlatformKind, UserAgentKind, Initiator}; + + +#[derive(Debug, Default)] +pub struct UserAgentDatabase { + profiles: HashMap, + + http_profiles: HashMap, + + #[cfg(feature = "tls")] + tls_profiles: HashMap, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct UserAgentProfileKey { + pub ua_kind: UserAgentKind, + pub ua_kind_version: usize, + pub platform_kind: PlatformKind, +} + +#[derive(Debug)] +struct UserAgentProfile { + pub ua_kind: UserAgentKind, + pub platform_kind: PlatformKind, + pub http_profiles: Vec, + + #[cfg(feature = "tls")] + pub tls_profiles: 
Vec, +} + +impl UserAgentProfile { + fn match_filters(&self, kind_mask: u8, platform_mask: u8) -> bool { + if self.http_profiles.is_empty() { + return false; + } + + #[cfg(feature = "tls")] + if self.tls_profiles.is_empty() { + return false; + } + + self.ua_kind as u8 & kind_mask != 0 && self.platform_kind as u8 & platform_mask != 0 + } +} + +impl UserAgentDatabase { + /// Create a new user agent database. + #[inline] + pub fn new() -> Self { + Self::default() + } +} + +#[derive(Debug, Clone, Default)] +pub struct UserAgentFilter { + pub kind: u8, + pub platform: u8, + pub initiator: Option, +} + +#[derive(Debug, Clone)] +pub struct UserAgentProfileQueryResult<'a> { + pub http: &'a crate::HttpProfile, + + #[cfg(feature = "tls")] + pub tls: &'a crate::TlsProfile, +} + +#[derive(Serialize, Deserialize)] +struct UserAgentFilterSerde { + kind: Option>, + platform: Option>, + initiator: Option, +} + +impl Serialize for UserAgentFilter { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + let mut kinds = Vec::new(); + if self.kind | UserAgentKind::Chromium as u8 != 0 { + kinds.push(UserAgentKind::Chromium); + } + if self.kind | UserAgentKind::Firefox as u8 != 0 { + kinds.push(UserAgentKind::Firefox); + } + if self.kind | UserAgentKind::Safari as u8 != 0 { + kinds.push(UserAgentKind::Safari); + } + + let mut platforms = Vec::new(); + if self.platform | PlatformKind::Windows as u8 != 0 { + platforms.push(PlatformKind::Windows); + } + if self.platform | PlatformKind::MacOS as u8 != 0 { + platforms.push(PlatformKind::MacOS); + } + if self.platform | PlatformKind::Linux as u8 != 0 { + platforms.push(PlatformKind::Linux); + } + if self.platform | PlatformKind::Android as u8 != 0 { + platforms.push(PlatformKind::Android); + } + if self.platform | PlatformKind::IOS as u8 != 0 { + platforms.push(PlatformKind::IOS); + } + + let filter = UserAgentFilterSerde { + kind: if kinds.is_empty() { None } else { Some(kinds) }, + platform: if 
platforms.is_empty() { None } else { Some(platforms) }, + initiator: self.initiator.clone(), + }; + filter.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for UserAgentFilter { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let filter = UserAgentFilterSerde::deserialize(deserializer)?; + let mut result = UserAgentFilter::default(); + if let Some(kinds) = filter.kind { + for kind in kinds { + result.kind |= kind as u8; + } + } + if let Some(platforms) = filter.platform { + for platform in platforms { + result.platform |= platform as u8; + } + } + if let Some(initiator) = filter.initiator { + result.initiator = Some(initiator); + } + Ok(result) + } +} + +impl UserAgentDatabase { + pub fn insert_http_profile(&mut self, profile: crate::UserAgentHttpProfile) { + let key = profile.key(); + self.profiles.entry(UserAgentProfileKey { + ua_kind: profile.ua_kind, + ua_kind_version: profile.ua_kind_version, + platform_kind: profile.platform_kind, + }).or_insert_with(|| UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + #[cfg(feature = "tls")] + tls_profiles: Vec::new(), + }).http_profiles.push(key); + self.http_profiles.insert(key, profile.http); + } + + #[cfg(feature = "tls")] + pub fn insert_tls_profile(&mut self, profile: crate::UserAgentTlsProfile) { + let key = profile.key(); + self.profiles.entry(UserAgentProfileKey { + ua_kind: profile.ua_kind, + ua_kind_version: profile.ua_kind_version, + platform_kind: profile.platform_kind, + }).or_insert_with(|| UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + tls_profiles: Vec::new(), + }).tls_profiles.push(key); + } + + pub fn query( + &self, + filters: Option, + ) -> Option> { + let filter = filters.unwrap_or_default(); + let mut rng = rand::rng(); + + let kind_mask = if filter.kind == 0 { + tracing::trace!("no kind filter provided, using 
all"); + u8::MAX + } else { + filter.kind + }; + + let platform_mask = if filter.platform == 0 { + tracing::trace!("no platform filter provided, using all"); + u8::MAX + } else { + filter.platform + }; + + let profiles: Vec<_> = self.profiles.values() + .filter(|profile| profile.match_filters(kind_mask, platform_mask)) + .collect(); + if profiles.is_empty() { + tracing::debug!(?filter, "no profiles found for provided filters"); + return None; + } else { + tracing::trace!(?filter, "found {} profile(s) for provided filters", profiles.len()); + } + + // market share from https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + let weights: Vec = profiles.iter().map(|profiles| match profiles.ua_kind { + UserAgentKind::Firefox => 0.03, + UserAgentKind::Safari => 0.18, + UserAgentKind::Chromium => 0.79, + }).collect(); + let dist = WeightedIndex::new(&weights).ok()?; + let profile = profiles.get(dist.sample(&mut rng))?; + + // try to get random http profile with initiator if defined, else random http profile + let http_profile_index = if let Some(initiator) = filter.initiator { + profile.http_profiles.iter().filter(|key| self.http_profiles.get(key).map(|http| http.initiator == initiator).unwrap_or(false)).choose(&mut rng) + } else { + profile.http_profiles.choose(&mut rng) + }?; + let http_profile = self.http_profiles.get(http_profile_index)?; + + + #[cfg(feature = "tls")] + profile.tls_profiles.choose(&mut rng).and_then(|key| { + self.tls_profiles.get(key) + })?; + + Some(UserAgentProfileQueryResult{ + http: http_profile, + #[cfg(feature = "tls")] + tls: tls_profile, + }) + } +} diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 9c92a8b8..5d026589 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -1,92 +1,37 @@ +use highway::HighwayHasher; use rama_core::error::OpaqueError; use rama_http_types::proto::h2::PseudoHeader; use serde::{Deserialize, Serialize}; -use std::str::FromStr; +use 
std::{borrow::Cow, hash::{Hash as _, Hasher as _}, str::FromStr}; -#[derive(Debug, Clone, Deserialize, Serialize)] -#[cfg_attr(feature = "memory-db", derive(venndb::VennDB))] -pub struct HttpProfile { - #[cfg_attr(feature = "memory-db", venndb(key))] - pub ja4h: String, - pub http_headers: Vec<(String, String)>, - pub http_pseudo_headers: Vec, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub fetch_mode: Option, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub resource_type: Option, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub initiator: Option, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub http_version: Option, -} +use crate::{PlatformKind, UserAgentKind}; -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] -pub enum FetchMode { - Cors, - Navigate, - NoCors, - SameOrigin, - Websocket, +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] +pub struct UserAgentHttpProfile { + pub ua_kind: UserAgentKind, + pub ua_kind_version: usize, + pub platform_kind: PlatformKind, + pub http: HttpProfile, } -impl std::fmt::Display for FetchMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Cors => write!(f, "cors"), - Self::Navigate => write!(f, "navigate"), - Self::NoCors => write!(f, "no-cors"), - Self::SameOrigin => write!(f, "same-origin"), - Self::Websocket => write!(f, "websocket"), - } - } -} - -impl FromStr for FetchMode { - type Err = String; - - fn from_str(s: &str) -> Result { - match s { - "cors" => Ok(Self::Cors), - "navigate" => Ok(Self::Navigate), - "no-cors" => Ok(Self::NoCors), - "same-origin" => Ok(Self::SameOrigin), - "websocket" => Ok(Self::Websocket), - _ => Err(s.to_owned()), - } +impl UserAgentHttpProfile { + pub fn key(&self) -> u64 { + let mut hasher = HighwayHasher::default(); + self.hash(&mut hasher); + hasher.finish() } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] -pub enum 
ResourceType { - Document, - Xhr, - Form, -} - -impl std::fmt::Display for ResourceType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Document => write!(f, "document"), - Self::Xhr => write!(f, "xhr"), - Self::Form => write!(f, "form"), - } - } -} - -impl FromStr for ResourceType { - type Err = String; - - fn from_str(s: &str) -> Result { - match s { - "document" => Ok(Self::Document), - "xhr" => Ok(Self::Xhr), - "form" => Ok(Self::Form), - _ => Err(s.to_owned()), - } - } +#[derive(Debug, Clone, Deserialize, Serialize, Hash)] +pub struct HttpProfile { + pub ja4h: String, + pub http_headers: Vec<(String, String)>, + pub http_pseudo_headers: Vec, + pub initiator: Initiator, + pub http_version: HttpVersion, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] pub enum Initiator { Navigator, Fetch, @@ -119,13 +64,23 @@ impl FromStr for Initiator { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum HttpVersion { H1, H2, H3, } +impl HttpVersion { + pub fn as_str(&self) -> &'static str { + match self { + Self::H1 => "http/1", + Self::H2 => "h2", + Self::H3 => "h3", + } + } +} + impl FromStr for HttpVersion { type Err = OpaqueError; @@ -142,3 +97,21 @@ impl FromStr for HttpVersion { }) } } + +impl Serialize for HttpVersion { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for HttpVersion { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + let s = >::deserialize(deserializer)?; + HttpVersion::from_str(&s).map_err(serde::de::Error::custom) + } +} + diff --git a/rama-ua/src/profile/mod.rs b/rama-ua/src/profile/mod.rs index 0be22223..f3b07e5d 100644 --- 
a/rama-ua/src/profile/mod.rs +++ b/rama-ua/src/profile/mod.rs @@ -1,6 +1,3 @@ -mod ua; -pub use ua::*; - mod http; pub use http::*; @@ -8,3 +5,6 @@ pub use http::*; mod tls; #[cfg(feature = "tls")] pub use tls::*; + +mod db; +pub use db::*; diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs index d99cd35e..7c686e3b 100644 --- a/rama-ua/src/profile/tls.rs +++ b/rama-ua/src/profile/tls.rs @@ -1,10 +1,26 @@ use rama_net::tls::client::ClientHello; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Deserialize, Serialize)] -#[cfg_attr(feature = "memory-db", derive(venndb::VennDB))] +use highway::HighwayHasher; + +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] +pub struct UserAgentTlsProfile { + pub ua_kind: UserAgentKind, + pub ua_kind_version: usize, + pub platform_kind: PlatformKind, + pub tls: TlsProfile, +} + +impl UserAgentTlsProfile { + pub fn key(&self) -> u64 { + let mut hasher = HighwayHasher::default(); + self.hash(&mut hasher); + hasher.finish() + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, Hash)] pub struct TlsProfile { - #[cfg_attr(feature = "memory-db", venndb(key))] pub ja4: String, pub client_hello: ClientHello, } diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs deleted file mode 100644 index ed38e675..00000000 --- a/rama-ua/src/profile/ua.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::{DeviceKind, UserAgentKind}; - -#[derive(Debug)] -#[cfg_attr(feature = "memory-db", derive(venndb::VennDB))] -pub struct UserAgentProfile { - #[cfg_attr(feature = "memory-db", venndb(key))] - pub header: String, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub kind: Option, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub platform_kind: Option, - #[cfg_attr(feature = "memory-db", venndb(filter))] - pub device_kind: Option, - pub version: Option, - - #[cfg(feature = "memory-db")] - pub http_profiles: crate::HttpProfileDB, - #[cfg(not(feature = "memory-db"))] - pub http_profiles: Vec, - - 
#[cfg(all(feature = "tls", feature = "memory-db"))] - pub tls_profiles: crate::TlsProfileDB, - #[cfg(all(feature = "tls", not(feature = "memory-db")))] - pub tls_profiles: Vec, -} - -// TODO support serialize / deseralize fo this struct and its property types -// TODO implement querying profiles -// TODO add query tests -// -// TODO: do we really need VennDB here, we might be better off flattening it different... -// also we need to take into account the market spread -// -// TODO should we strip out heavy duplicate data? e.g. there is probably a lot of duplication -// in the TlsProfileData (inner) and HttpProfileData (inner) -// -// TODO: do we need to really take into account initiator, fetch and Resource type? diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index 25c4c441..27de84a6 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -171,84 +171,193 @@ impl FromStr for UserAgent { /// The kind of [`UserAgent`] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] pub enum UserAgentKind { /// Chromium Browser - Chromium, + Chromium = 0b0000_0001, /// Firefox Browser - Firefox, + Firefox = 0b0000_0010, /// Safari Browser - Safari, + Safari = 0b0000_0100, +} + +impl UserAgentKind { + pub fn as_str(&self) -> &'static str { + match self { + UserAgentKind::Chromium => "Chromium", + UserAgentKind::Firefox => "Firefox", + UserAgentKind::Safari => "Safari", + } + } } impl fmt::Display for UserAgentKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - UserAgentKind::Chromium => write!(f, "Chromium"), - UserAgentKind::Firefox => write!(f, "Firefox"), - UserAgentKind::Safari => write!(f, "Safari"), + write!(f, "{}", self.as_str()) + } +} + +impl FromStr for UserAgentKind { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + match_ignore_ascii_case_str! 
{ + match (s) { + "chromium" => Ok(UserAgentKind::Chromium), + "firefox" => Ok(UserAgentKind::Firefox), + "safari" => Ok(UserAgentKind::Safari), + _ => Err(OpaqueError::from_display(format!("invalid user agent kind: {}", s))), + } } } } +impl Serialize for UserAgentKind { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for UserAgentKind { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + s.parse::().map_err(serde::de::Error::custom) + } +} + /// Device on which the [`UserAgent`] operates. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] pub enum DeviceKind { /// Personal Computers - Desktop, + Desktop = 0b0000_0001, /// Phones, Tablets and other mobile devices - Mobile, + Mobile = 0b0000_0010, } -impl fmt::Display for DeviceKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl DeviceKind { + pub fn as_str(&self) -> &'static str { match self { - DeviceKind::Desktop => write!(f, "Desktop"), - DeviceKind::Mobile => write!(f, "Mobile"), + DeviceKind::Desktop => "Desktop", + DeviceKind::Mobile => "Mobile", } } } +impl fmt::Display for DeviceKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + /// Platform within the [`UserAgent`] operates. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] pub enum PlatformKind { /// Windows Platform ([`Desktop`](DeviceKind::Desktop)) - Windows, + Windows = 0b0000_0001, /// MacOS Platform ([`Desktop`](DeviceKind::Desktop)) - MacOS, + MacOS = 0b0000_0010, /// Linux Platform ([`Desktop`](DeviceKind::Desktop)) - Linux, + Linux = 0b0000_0100, /// Android Platform ([`Mobile`](DeviceKind::Mobile)) - Android, + Android = 0b0001_0000, /// iOS Platform ([`Mobile`](DeviceKind::Mobile)) - IOS, + IOS = 0b0010_0000, } -impl fmt::Display for PlatformKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl PlatformKind { + pub fn as_str(&self) -> &'static str { match self { - PlatformKind::Windows => write!(f, "Windows"), - PlatformKind::MacOS => write!(f, "MacOS"), - PlatformKind::Linux => write!(f, "Linux"), - PlatformKind::Android => write!(f, "Android"), - PlatformKind::IOS => write!(f, "iOS"), + PlatformKind::Windows => "Windows", + PlatformKind::MacOS => "MacOS", + PlatformKind::Linux => "Linux", + PlatformKind::Android => "Android", + PlatformKind::IOS => "iOS", + } + } +} + +impl FromStr for PlatformKind { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + match_ignore_ascii_case_str! 
{ + match (s) { + "windows" => Ok(PlatformKind::Windows), + "macos" => Ok(PlatformKind::MacOS), + "linux" => Ok(PlatformKind::Linux), + "android" => Ok(PlatformKind::Android), + "ios" => Ok(PlatformKind::IOS), + _ => Err(OpaqueError::from_display(format!("invalid platform: {}", s))), + } } } } +impl Serialize for PlatformKind { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for PlatformKind { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + s.parse::().map_err(serde::de::Error::custom) + } +} + +impl fmt::Display for PlatformKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + /// Http implementation used by the [`UserAgent`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[repr(u8)] pub enum HttpAgent { /// Chromium based browsers share the same http implementation - Chromium, + Chromium = 0b0000_0001, /// Firefox has its own http implementation - Firefox, + Firefox = 0b0000_0010, /// Safari also has its own http implementation - Safari, + Safari = 0b0000_0100, /// Preserve the incoming Http Agent as much as possible. /// /// For emulators this means that emulators will aim to have a /// hands-off approach to the incoming http request. 
- Preserve, + Preserve = 0b1000_0000, +} + +impl HttpAgent { + pub fn as_str(&self) -> &'static str { + match self { + HttpAgent::Chromium => "Chromium", + HttpAgent::Firefox => "Firefox", + HttpAgent::Safari => "Safari", + HttpAgent::Preserve => "Preserve", + } + } +} + +impl fmt::Display for HttpAgent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } } impl Serialize for HttpAgent { @@ -256,12 +365,7 @@ impl Serialize for HttpAgent { where S: serde::ser::Serializer, { - match self { - HttpAgent::Chromium => serializer.serialize_str("Chromium"), - HttpAgent::Firefox => serializer.serialize_str("Firefox"), - HttpAgent::Safari => serializer.serialize_str("Safari"), - HttpAgent::Preserve => serializer.serialize_str("Preserve"), - } + serializer.serialize_str(self.as_str()) } } @@ -271,15 +375,7 @@ impl<'de> Deserialize<'de> for HttpAgent { D: Deserializer<'de>, { let s = >::deserialize(deserializer)?; - match_ignore_ascii_case_str! { - match (s) { - "chrome" | "chromium" => Ok(HttpAgent::Chromium), - "Firefox" => Ok(HttpAgent::Firefox), - "Safari" => Ok(HttpAgent::Safari), - "preserve" => Ok(HttpAgent::Preserve), - _ => Err(serde::de::Error::custom("invalid http agent")), - } - } + s.parse::().map_err(serde::de::Error::custom) } } @@ -299,57 +395,48 @@ impl FromStr for HttpAgent { } } -impl fmt::Display for HttpAgent { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - HttpAgent::Chromium => write!(f, "Chromium"), - HttpAgent::Firefox => write!(f, "Firefox"), - HttpAgent::Safari => write!(f, "Safari"), - HttpAgent::Preserve => write!(f, "Preserve"), - } - } -} - /// Tls implementation used by the [`UserAgent`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[repr(u8)] pub enum TlsAgent { /// Rustls is used as a fallback for all user agents, /// that are not chromium based. - Rustls, + Rustls = 0b0000_0001, /// Boringssl is used for Chromium based user agents. 
- Boringssl, + Boringssl = 0b0000_0010, /// NSS is used for Firefox - Nss, + Nss = 0b0000_0100, /// Preserve the incoming TlsAgent as much as possible. /// /// For this Tls this means that emulators can try to /// preserve details of the incoming Tls connection /// such as the (Tls) Client Hello. - Preserve, + Preserve = 0b1000_0000, } -impl fmt::Display for TlsAgent { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl TlsAgent { + pub fn as_str(&self) -> &'static str { match self { - TlsAgent::Rustls => write!(f, "Rustls"), - TlsAgent::Boringssl => write!(f, "Boringssl"), - TlsAgent::Nss => write!(f, "NSS"), - TlsAgent::Preserve => write!(f, "Preserve"), + TlsAgent::Rustls => "Rustls", + TlsAgent::Boringssl => "Boringssl", + TlsAgent::Nss => "NSS", + TlsAgent::Preserve => "Preserve", } } } +impl fmt::Display for TlsAgent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + impl Serialize for TlsAgent { fn serialize(&self, serializer: S) -> Result where S: serde::ser::Serializer, { - match self { - TlsAgent::Rustls => serializer.serialize_str("Rustls"), - TlsAgent::Boringssl => serializer.serialize_str("Boringssl"), - TlsAgent::Nss => serializer.serialize_str("NSS"), - TlsAgent::Preserve => serializer.serialize_str("Preserve"), - } + serializer.serialize_str(self.as_str()) } } @@ -359,15 +446,7 @@ impl<'de> Deserialize<'de> for TlsAgent { D: Deserializer<'de>, { let s = >::deserialize(deserializer)?; - match_ignore_ascii_case_str! 
{ - match (s) { - "rustls" => Ok(TlsAgent::Rustls), - "boring" | "boringssl" => Ok(TlsAgent::Boringssl), - "nss" => Ok(TlsAgent::Nss), - "preserve" => Ok(TlsAgent::Preserve), - _ => Err(serde::de::Error::custom("invalid tls agent")), - } - } + s.parse::().map_err(serde::de::Error::custom) } } From effd81c07bf42bb30aa44ef9284976c3ac47cfaa Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Thu, 6 Feb 2025 19:02:29 +0100 Subject: [PATCH 05/39] fix QA lints (UA) --- rama-net/src/tls/client/hello/mod.rs | 4 +- rama-net/src/tls/enums/mod.rs | 2 +- rama-ua/Cargo.toml | 4 +- rama-ua/src/profile/db.rs | 118 +++++++++++++++++---------- rama-ua/src/profile/http.rs | 13 ++- rama-ua/src/profile/tls.rs | 4 + 6 files changed, 93 insertions(+), 52 deletions(-) diff --git a/rama-net/src/tls/client/hello/mod.rs b/rama-net/src/tls/client/hello/mod.rs index 25e31d8b..80c248e6 100644 --- a/rama-net/src/tls/client/hello/mod.rs +++ b/rama-net/src/tls/client/hello/mod.rs @@ -12,7 +12,7 @@ mod rustls; #[cfg(feature = "boring")] mod boring; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] /// When a client first connects to a server, it is required to send /// the ClientHello as its first message. /// @@ -127,7 +127,7 @@ impl ClientHello { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] /// Extensions that can be set in a [`ClientHello`] message by a TLS client. /// /// While its name may infer that an extension is by definition optional, diff --git a/rama-net/src/tls/enums/mod.rs b/rama-net/src/tls/enums/mod.rs index 44a141b2..9495f3b4 100644 --- a/rama-net/src/tls/enums/mod.rs +++ b/rama-net/src/tls/enums/mod.rs @@ -184,7 +184,7 @@ macro_rules! 
enum_builder { ) => { $(#[$comment])* #[non_exhaustive] - #[derive(Debug, PartialEq, Eq, Clone)] + #[derive(Debug, PartialEq, Eq, Clone, Hash)] $enum_vis enum $enum_name { $( $enum_var),* ,Unknown(Vec) diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index febfb6b3..f3925594 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -19,14 +19,14 @@ tls = ["dep:rama-net", "rama-net/tls"] [dependencies] bytes = { workspace = true } +highway = { workspace = true } rama-core = { version = "0.2.0-alpha.7", path = "../rama-core" } rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", optional = true } rama-utils = { version = "0.2.0-alpha.7", path = "../rama-utils" } -serde = { workspace = true, features = ["derive"] } rand = { workspace = true } +serde = { workspace = true, features = ["derive"] } tracing = { workspace = true } -highway = { workspace = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index a004a856..6019804c 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -1,16 +1,18 @@ -use std::collections::HashMap; -use rand::{distr::{weighted::WeightedIndex, Distribution as _}, seq::{IndexedRandom as _, IteratorRandom as _}}; +use rand::{ + distr::{weighted::WeightedIndex, Distribution as _}, + seq::{IndexedRandom as _, IteratorRandom as _}, +}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; -use crate::{PlatformKind, UserAgentKind, Initiator}; - +use crate::{Initiator, PlatformKind, UserAgentKind}; #[derive(Debug, Default)] pub struct UserAgentDatabase { profiles: HashMap, http_profiles: HashMap, - + #[cfg(feature = "tls")] tls_profiles: HashMap, } @@ -109,11 +111,15 @@ impl Serialize for UserAgentFilter { if self.platform | PlatformKind::IOS as u8 != 0 { platforms.push(PlatformKind::IOS); } - + let filter = UserAgentFilterSerde { kind: if 
kinds.is_empty() { None } else { Some(kinds) }, - platform: if platforms.is_empty() { None } else { Some(platforms) }, - initiator: self.initiator.clone(), + platform: if platforms.is_empty() { + None + } else { + Some(platforms) + }, + initiator: self.initiator, }; filter.serialize(serializer) } @@ -146,33 +152,41 @@ impl<'de> Deserialize<'de> for UserAgentFilter { impl UserAgentDatabase { pub fn insert_http_profile(&mut self, profile: crate::UserAgentHttpProfile) { let key = profile.key(); - self.profiles.entry(UserAgentProfileKey { - ua_kind: profile.ua_kind, - ua_kind_version: profile.ua_kind_version, - platform_kind: profile.platform_kind, - }).or_insert_with(|| UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - #[cfg(feature = "tls")] - tls_profiles: Vec::new(), - }).http_profiles.push(key); + self.profiles + .entry(UserAgentProfileKey { + ua_kind: profile.ua_kind, + ua_kind_version: profile.ua_kind_version, + platform_kind: profile.platform_kind, + }) + .or_insert_with(|| UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + #[cfg(feature = "tls")] + tls_profiles: Vec::new(), + }) + .http_profiles + .push(key); self.http_profiles.insert(key, profile.http); } #[cfg(feature = "tls")] pub fn insert_tls_profile(&mut self, profile: crate::UserAgentTlsProfile) { let key = profile.key(); - self.profiles.entry(UserAgentProfileKey { - ua_kind: profile.ua_kind, - ua_kind_version: profile.ua_kind_version, - platform_kind: profile.platform_kind, - }).or_insert_with(|| UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - tls_profiles: Vec::new(), - }).tls_profiles.push(key); + self.profiles + .entry(UserAgentProfileKey { + ua_kind: profile.ua_kind, + ua_kind_version: profile.ua_kind_version, + platform_kind: profile.platform_kind, + }) + .or_insert_with(|| UserAgentProfile { + 
ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + tls_profiles: Vec::new(), + }) + .tls_profiles + .push(key); } pub fn query( @@ -195,41 +209,59 @@ impl UserAgentDatabase { } else { filter.platform }; - - let profiles: Vec<_> = self.profiles.values() + + let profiles: Vec<_> = self + .profiles + .values() .filter(|profile| profile.match_filters(kind_mask, platform_mask)) .collect(); if profiles.is_empty() { tracing::debug!(?filter, "no profiles found for provided filters"); return None; } else { - tracing::trace!(?filter, "found {} profile(s) for provided filters", profiles.len()); + tracing::trace!( + ?filter, + "found {} profile(s) for provided filters", + profiles.len() + ); } // market share from https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) - let weights: Vec = profiles.iter().map(|profiles| match profiles.ua_kind { - UserAgentKind::Firefox => 0.03, - UserAgentKind::Safari => 0.18, - UserAgentKind::Chromium => 0.79, - }).collect(); + let weights: Vec = profiles + .iter() + .map(|profiles| match profiles.ua_kind { + UserAgentKind::Firefox => 0.03, + UserAgentKind::Safari => 0.18, + UserAgentKind::Chromium => 0.79, + }) + .collect(); let dist = WeightedIndex::new(&weights).ok()?; let profile = profiles.get(dist.sample(&mut rng))?; // try to get random http profile with initiator if defined, else random http profile let http_profile_index = if let Some(initiator) = filter.initiator { - profile.http_profiles.iter().filter(|key| self.http_profiles.get(key).map(|http| http.initiator == initiator).unwrap_or(false)).choose(&mut rng) + profile + .http_profiles + .iter() + .filter(|key| { + self.http_profiles + .get(key) + .map(|http| http.initiator == initiator) + .unwrap_or(false) + }) + .choose(&mut rng) } else { profile.http_profiles.choose(&mut rng) }?; let http_profile = self.http_profiles.get(http_profile_index)?; - #[cfg(feature = "tls")] - profile.tls_profiles.choose(&mut 
rng).and_then(|key| { - self.tls_profiles.get(key) - })?; + let tls_profile = profile + .tls_profiles + .choose(&mut rng) + .and_then(|key| self.tls_profiles.get(key))?; - Some(UserAgentProfileQueryResult{ + Some(UserAgentProfileQueryResult { http: http_profile, #[cfg(feature = "tls")] tls: tls_profile, diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 5d026589..88dce8bd 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -2,7 +2,11 @@ use highway::HighwayHasher; use rama_core::error::OpaqueError; use rama_http_types::proto::h2::PseudoHeader; use serde::{Deserialize, Serialize}; -use std::{borrow::Cow, hash::{Hash as _, Hasher as _}, str::FromStr}; +use std::{ + borrow::Cow, + hash::{Hash as _, Hasher as _}, + str::FromStr, +}; use crate::{PlatformKind, UserAgentKind}; @@ -101,7 +105,8 @@ impl FromStr for HttpVersion { impl Serialize for HttpVersion { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer { + S: serde::Serializer, + { serializer.serialize_str(self.as_str()) } } @@ -109,9 +114,9 @@ impl Serialize for HttpVersion { impl<'de> Deserialize<'de> for HttpVersion { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> { + D: serde::Deserializer<'de>, + { let s = >::deserialize(deserializer)?; HttpVersion::from_str(&s).map_err(serde::de::Error::custom) } } - diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs index 7c686e3b..bbb72687 100644 --- a/rama-ua/src/profile/tls.rs +++ b/rama-ua/src/profile/tls.rs @@ -1,8 +1,12 @@ +use std::hash::{Hash as _, Hasher as _}; + use rama_net::tls::client::ClientHello; use serde::{Deserialize, Serialize}; use highway::HighwayHasher; +use crate::{PlatformKind, UserAgentKind}; + #[derive(Debug, Clone, Serialize, Deserialize, Hash)] pub struct UserAgentTlsProfile { pub ua_kind: UserAgentKind, From 1307de146ee1f4ff4e48d0cd0cf3d0eb6556fc95 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Fri, 7 Feb 2025 
09:13:15 +0100 Subject: [PATCH 06/39] making ua profiles a single vec for runtime usage --- rama-ua/src/profile/db.rs | 57 ++++++++++++++++++++++--------------- rama-ua/src/profile/http.rs | 8 ++++++ rama-ua/src/profile/tls.rs | 8 ++++++ 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 6019804c..0c687c58 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -9,7 +9,8 @@ use crate::{Initiator, PlatformKind, UserAgentKind}; #[derive(Debug, Default)] pub struct UserAgentDatabase { - profiles: HashMap, + profile_keys: HashMap, + profiles: Vec, http_profiles: HashMap, @@ -151,42 +152,52 @@ impl<'de> Deserialize<'de> for UserAgentFilter { impl UserAgentDatabase { pub fn insert_http_profile(&mut self, profile: crate::UserAgentHttpProfile) { - let key = profile.key(); - self.profiles + let index = *self + .profile_keys .entry(UserAgentProfileKey { ua_kind: profile.ua_kind, ua_kind_version: profile.ua_kind_version, platform_kind: profile.platform_kind, }) - .or_insert_with(|| UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - #[cfg(feature = "tls")] - tls_profiles: Vec::new(), - }) - .http_profiles - .push(key); + .or_insert_with(|| { + let idx = self.profiles.len(); + self.profiles.push(UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + #[cfg(feature = "tls")] + tls_profiles: Vec::new(), + }); + idx + }); + let key = profile.http.key(); self.http_profiles.insert(key, profile.http); + self.profiles[index].http_profiles.push(key); } #[cfg(feature = "tls")] pub fn insert_tls_profile(&mut self, profile: crate::UserAgentTlsProfile) { - let key = profile.key(); - self.profiles + let index = *self + .profile_keys .entry(UserAgentProfileKey { ua_kind: profile.ua_kind, ua_kind_version: profile.ua_kind_version, platform_kind: profile.platform_kind, }) - 
.or_insert_with(|| UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - tls_profiles: Vec::new(), - }) - .tls_profiles - .push(key); + .or_insert_with(|| { + let idx = self.profiles.len(); + self.profiles.push(UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + #[cfg(feature = "tls")] + tls_profiles: Vec::new(), + }); + idx + }); + let key = profile.tls.key(); + self.tls_profiles.insert(key, profile.tls); + self.profiles[index].tls_profiles.push(key); } pub fn query( @@ -212,7 +223,7 @@ impl UserAgentDatabase { let profiles: Vec<_> = self .profiles - .values() + .iter() .filter(|profile| profile.match_filters(kind_mask, platform_mask)) .collect(); if profiles.is_empty() { diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 88dce8bd..e70f4d2c 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -35,6 +35,14 @@ pub struct HttpProfile { pub http_version: HttpVersion, } +impl HttpProfile { + pub fn key(&self) -> u64 { + let mut hasher = HighwayHasher::default(); + self.hash(&mut hasher); + hasher.finish() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] pub enum Initiator { Navigator, diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs index bbb72687..1aa80bf6 100644 --- a/rama-ua/src/profile/tls.rs +++ b/rama-ua/src/profile/tls.rs @@ -28,3 +28,11 @@ pub struct TlsProfile { pub ja4: String, pub client_hello: ClientHello, } + +impl TlsProfile { + pub fn key(&self) -> u64 { + let mut hasher = HighwayHasher::default(); + self.hash(&mut hasher); + hasher.finish() + } +} From 6bd8a29ff6d82982898c602f443c89e58b722520 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 10 Feb 2025 09:56:05 +0100 Subject: [PATCH 07/39] prepare rama-ua for new profile structure next up... db.. 
again again --- rama-http-types/src/proto/h1/headers/map.rs | 42 ++- rama-ua/src/profile/db.rs | 282 +------------------- rama-ua/src/profile/http.rs | 137 ++-------- rama-ua/src/profile/mod.rs | 3 + rama-ua/src/profile/tls.rs | 31 --- rama-ua/src/profile/todo_delete_db.rs | 281 +++++++++++++++++++ rama-ua/src/profile/ua.rs | 20 ++ 7 files changed, 363 insertions(+), 433 deletions(-) create mode 100644 rama-ua/src/profile/todo_delete_db.rs create mode 100644 rama-ua/src/profile/ua.rs diff --git a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs index 409641ea..c9f2469d 100644 --- a/rama-http-types/src/proto/h1/headers/map.rs +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -1,4 +1,9 @@ -use std::collections::{self, HashMap}; +use std::{ + borrow::Cow, + collections::{self, HashMap}, +}; + +use serde::{de::Error as _, ser::Error as _, Deserialize, Serialize}; use super::{ name::{IntoHttp1HeaderName, IntoSealed as _, TryIntoHttp1HeaderName}, @@ -132,6 +137,41 @@ impl IntoIterator for Http1HeaderMap { } } +impl Serialize for Http1HeaderMap { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let headers: Result, _> = self + .clone() + .into_iter() + .map(|(name, value)| { + let value = value.to_str().map_err(S::Error::custom)?; + Ok::<_, S::Error>((name, value.to_owned())) + }) + .collect(); + headers?.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Http1HeaderMap { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let headers = )>>::deserialize(deserializer)?; + headers + .into_iter() + .map(|(name, value)| { + Ok::<_, D::Error>(( + name, + HeaderValue::from_str(&value).map_err(D::Error::custom)?, + )) + }) + .collect() + } +} + #[derive(Debug)] pub struct Http1HeaderMapIntoIter { state: Http1HeaderMapIntoIterState, diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 0c687c58..d48b9f75 100644 --- 
a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -1,281 +1 @@ -use rand::{ - distr::{weighted::WeightedIndex, Distribution as _}, - seq::{IndexedRandom as _, IteratorRandom as _}, -}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -use crate::{Initiator, PlatformKind, UserAgentKind}; - -#[derive(Debug, Default)] -pub struct UserAgentDatabase { - profile_keys: HashMap, - profiles: Vec, - - http_profiles: HashMap, - - #[cfg(feature = "tls")] - tls_profiles: HashMap, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct UserAgentProfileKey { - pub ua_kind: UserAgentKind, - pub ua_kind_version: usize, - pub platform_kind: PlatformKind, -} - -#[derive(Debug)] -struct UserAgentProfile { - pub ua_kind: UserAgentKind, - pub platform_kind: PlatformKind, - pub http_profiles: Vec, - - #[cfg(feature = "tls")] - pub tls_profiles: Vec, -} - -impl UserAgentProfile { - fn match_filters(&self, kind_mask: u8, platform_mask: u8) -> bool { - if self.http_profiles.is_empty() { - return false; - } - - #[cfg(feature = "tls")] - if self.tls_profiles.is_empty() { - return false; - } - - self.ua_kind as u8 & kind_mask != 0 && self.platform_kind as u8 & platform_mask != 0 - } -} - -impl UserAgentDatabase { - /// Create a new user agent database. 
- #[inline] - pub fn new() -> Self { - Self::default() - } -} - -#[derive(Debug, Clone, Default)] -pub struct UserAgentFilter { - pub kind: u8, - pub platform: u8, - pub initiator: Option, -} - -#[derive(Debug, Clone)] -pub struct UserAgentProfileQueryResult<'a> { - pub http: &'a crate::HttpProfile, - - #[cfg(feature = "tls")] - pub tls: &'a crate::TlsProfile, -} - -#[derive(Serialize, Deserialize)] -struct UserAgentFilterSerde { - kind: Option>, - platform: Option>, - initiator: Option, -} - -impl Serialize for UserAgentFilter { - fn serialize(&self, serializer: S) -> Result - where - S: serde::ser::Serializer, - { - let mut kinds = Vec::new(); - if self.kind | UserAgentKind::Chromium as u8 != 0 { - kinds.push(UserAgentKind::Chromium); - } - if self.kind | UserAgentKind::Firefox as u8 != 0 { - kinds.push(UserAgentKind::Firefox); - } - if self.kind | UserAgentKind::Safari as u8 != 0 { - kinds.push(UserAgentKind::Safari); - } - - let mut platforms = Vec::new(); - if self.platform | PlatformKind::Windows as u8 != 0 { - platforms.push(PlatformKind::Windows); - } - if self.platform | PlatformKind::MacOS as u8 != 0 { - platforms.push(PlatformKind::MacOS); - } - if self.platform | PlatformKind::Linux as u8 != 0 { - platforms.push(PlatformKind::Linux); - } - if self.platform | PlatformKind::Android as u8 != 0 { - platforms.push(PlatformKind::Android); - } - if self.platform | PlatformKind::IOS as u8 != 0 { - platforms.push(PlatformKind::IOS); - } - - let filter = UserAgentFilterSerde { - kind: if kinds.is_empty() { None } else { Some(kinds) }, - platform: if platforms.is_empty() { - None - } else { - Some(platforms) - }, - initiator: self.initiator, - }; - filter.serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for UserAgentFilter { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - let filter = UserAgentFilterSerde::deserialize(deserializer)?; - let mut result = UserAgentFilter::default(); - if let Some(kinds) = 
filter.kind { - for kind in kinds { - result.kind |= kind as u8; - } - } - if let Some(platforms) = filter.platform { - for platform in platforms { - result.platform |= platform as u8; - } - } - if let Some(initiator) = filter.initiator { - result.initiator = Some(initiator); - } - Ok(result) - } -} - -impl UserAgentDatabase { - pub fn insert_http_profile(&mut self, profile: crate::UserAgentHttpProfile) { - let index = *self - .profile_keys - .entry(UserAgentProfileKey { - ua_kind: profile.ua_kind, - ua_kind_version: profile.ua_kind_version, - platform_kind: profile.platform_kind, - }) - .or_insert_with(|| { - let idx = self.profiles.len(); - self.profiles.push(UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - #[cfg(feature = "tls")] - tls_profiles: Vec::new(), - }); - idx - }); - let key = profile.http.key(); - self.http_profiles.insert(key, profile.http); - self.profiles[index].http_profiles.push(key); - } - - #[cfg(feature = "tls")] - pub fn insert_tls_profile(&mut self, profile: crate::UserAgentTlsProfile) { - let index = *self - .profile_keys - .entry(UserAgentProfileKey { - ua_kind: profile.ua_kind, - ua_kind_version: profile.ua_kind_version, - platform_kind: profile.platform_kind, - }) - .or_insert_with(|| { - let idx = self.profiles.len(); - self.profiles.push(UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - #[cfg(feature = "tls")] - tls_profiles: Vec::new(), - }); - idx - }); - let key = profile.tls.key(); - self.tls_profiles.insert(key, profile.tls); - self.profiles[index].tls_profiles.push(key); - } - - pub fn query( - &self, - filters: Option, - ) -> Option> { - let filter = filters.unwrap_or_default(); - let mut rng = rand::rng(); - - let kind_mask = if filter.kind == 0 { - tracing::trace!("no kind filter provided, using all"); - u8::MAX - } else { - filter.kind - }; - - let platform_mask = if filter.platform == 0 { - 
tracing::trace!("no platform filter provided, using all"); - u8::MAX - } else { - filter.platform - }; - - let profiles: Vec<_> = self - .profiles - .iter() - .filter(|profile| profile.match_filters(kind_mask, platform_mask)) - .collect(); - if profiles.is_empty() { - tracing::debug!(?filter, "no profiles found for provided filters"); - return None; - } else { - tracing::trace!( - ?filter, - "found {} profile(s) for provided filters", - profiles.len() - ); - } - - // market share from https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) - let weights: Vec = profiles - .iter() - .map(|profiles| match profiles.ua_kind { - UserAgentKind::Firefox => 0.03, - UserAgentKind::Safari => 0.18, - UserAgentKind::Chromium => 0.79, - }) - .collect(); - let dist = WeightedIndex::new(&weights).ok()?; - let profile = profiles.get(dist.sample(&mut rng))?; - - // try to get random http profile with initiator if defined, else random http profile - let http_profile_index = if let Some(initiator) = filter.initiator { - profile - .http_profiles - .iter() - .filter(|key| { - self.http_profiles - .get(key) - .map(|http| http.initiator == initiator) - .unwrap_or(false) - }) - .choose(&mut rng) - } else { - profile.http_profiles.choose(&mut rng) - }?; - let http_profile = self.http_profiles.get(http_profile_index)?; - - #[cfg(feature = "tls")] - let tls_profile = profile - .tls_profiles - .choose(&mut rng) - .and_then(|key| self.tls_profiles.get(key))?; - - Some(UserAgentProfileQueryResult { - http: http_profile, - #[cfg(feature = "tls")] - tls: tls_profile, - }) - } -} +pub const TODO: usize = 42; diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index e70f4d2c..67873f90 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -1,130 +1,27 @@ -use highway::HighwayHasher; -use rama_core::error::OpaqueError; -use rama_http_types::proto::h2::PseudoHeader; +use rama_http_types::proto::{h1::Http1HeaderMap, h2::PseudoHeader}; use 
serde::{Deserialize, Serialize}; -use std::{ - borrow::Cow, - hash::{Hash as _, Hasher as _}, - str::FromStr, -}; -use crate::{PlatformKind, UserAgentKind}; - -#[derive(Debug, Clone, Serialize, Deserialize, Hash)] -pub struct UserAgentHttpProfile { - pub ua_kind: UserAgentKind, - pub ua_kind_version: usize, - pub platform_kind: PlatformKind, - pub http: HttpProfile, -} - -impl UserAgentHttpProfile { - pub fn key(&self) -> u64 { - let mut hasher = HighwayHasher::default(); - self.hash(&mut hasher); - hasher.finish() - } -} - -#[derive(Debug, Clone, Deserialize, Serialize, Hash)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct HttpProfile { - pub ja4h: String, - pub http_headers: Vec<(String, String)>, - pub http_pseudo_headers: Vec, - pub initiator: Initiator, - pub http_version: HttpVersion, + pub headers: HttpHeadersProfile, + pub h1: Http1Profile, + pub h2: Http2Profile, } -impl HttpProfile { - pub fn key(&self) -> u64 { - let mut hasher = HighwayHasher::default(); - self.hash(&mut hasher); - hasher.finish() - } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct HttpHeadersProfile { + pub navigate: Http1HeaderMap, + pub fetch: Http1HeaderMap, + pub xhr: Http1HeaderMap, + pub form: Http1HeaderMap, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] -pub enum Initiator { - Navigator, - Fetch, - XMLHttpRequest, - Form, -} - -impl std::fmt::Display for Initiator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Navigator => write!(f, "navigator"), - Self::Fetch => write!(f, "fetch"), - Self::XMLHttpRequest => write!(f, "xmlhttprequest"), - Self::Form => write!(f, "form"), - } - } -} - -impl FromStr for Initiator { - type Err = String; - - fn from_str(s: &str) -> Result { - match s { - "navigator" => Ok(Self::Navigator), - "fetch" => Ok(Self::Fetch), - "xmlhttprequest" => Ok(Self::XMLHttpRequest), - "form" => Ok(Self::Form), - _ => Err(s.to_owned()), - } - } 
+#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Http1Profile { + pub title_case_headers: bool, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum HttpVersion { - H1, - H2, - H3, -} - -impl HttpVersion { - pub fn as_str(&self) -> &'static str { - match self { - Self::H1 => "http/1", - Self::H2 => "h2", - Self::H3 => "h3", - } - } -} - -impl FromStr for HttpVersion { - type Err = OpaqueError; - - fn from_str(s: &str) -> Result { - Ok(match s.trim().to_lowercase().as_str() { - "h1" | "http1" | "http/1" | "http/1.0" | "http/1.1" => Self::H1, - "h2" | "http2" | "http/2" | "http/2.0" => Self::H2, - "h3" | "http3" | "http/3" | "http/3.0" => Self::H3, - version => { - return Err(OpaqueError::from_display(format!( - "unsupported http version: {version}" - ))) - } - }) - } -} - -impl Serialize for HttpVersion { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(self.as_str()) - } -} - -impl<'de> Deserialize<'de> for HttpVersion { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = >::deserialize(deserializer)?; - HttpVersion::from_str(&s).map_err(serde::de::Error::custom) - } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Http2Profile { + pub http_pseudo_headers: Vec, } diff --git a/rama-ua/src/profile/mod.rs b/rama-ua/src/profile/mod.rs index f3b07e5d..662e0455 100644 --- a/rama-ua/src/profile/mod.rs +++ b/rama-ua/src/profile/mod.rs @@ -6,5 +6,8 @@ mod tls; #[cfg(feature = "tls")] pub use tls::*; +mod ua; +pub use ua::*; + mod db; pub use db::*; diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs index 1aa80bf6..1d88ba69 100644 --- a/rama-ua/src/profile/tls.rs +++ b/rama-ua/src/profile/tls.rs @@ -1,38 +1,7 @@ -use std::hash::{Hash as _, Hasher as _}; - use rama_net::tls::client::ClientHello; use serde::{Deserialize, Serialize}; -use highway::HighwayHasher; - -use crate::{PlatformKind, UserAgentKind}; - 
-#[derive(Debug, Clone, Serialize, Deserialize, Hash)] -pub struct UserAgentTlsProfile { - pub ua_kind: UserAgentKind, - pub ua_kind_version: usize, - pub platform_kind: PlatformKind, - pub tls: TlsProfile, -} - -impl UserAgentTlsProfile { - pub fn key(&self) -> u64 { - let mut hasher = HighwayHasher::default(); - self.hash(&mut hasher); - hasher.finish() - } -} - #[derive(Debug, Clone, Deserialize, Serialize, Hash)] pub struct TlsProfile { - pub ja4: String, pub client_hello: ClientHello, } - -impl TlsProfile { - pub fn key(&self) -> u64 { - let mut hasher = HighwayHasher::default(); - self.hash(&mut hasher); - hasher.finish() - } -} diff --git a/rama-ua/src/profile/todo_delete_db.rs b/rama-ua/src/profile/todo_delete_db.rs new file mode 100644 index 00000000..0c687c58 --- /dev/null +++ b/rama-ua/src/profile/todo_delete_db.rs @@ -0,0 +1,281 @@ +use rand::{ + distr::{weighted::WeightedIndex, Distribution as _}, + seq::{IndexedRandom as _, IteratorRandom as _}, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::{Initiator, PlatformKind, UserAgentKind}; + +#[derive(Debug, Default)] +pub struct UserAgentDatabase { + profile_keys: HashMap, + profiles: Vec, + + http_profiles: HashMap, + + #[cfg(feature = "tls")] + tls_profiles: HashMap, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct UserAgentProfileKey { + pub ua_kind: UserAgentKind, + pub ua_kind_version: usize, + pub platform_kind: PlatformKind, +} + +#[derive(Debug)] +struct UserAgentProfile { + pub ua_kind: UserAgentKind, + pub platform_kind: PlatformKind, + pub http_profiles: Vec, + + #[cfg(feature = "tls")] + pub tls_profiles: Vec, +} + +impl UserAgentProfile { + fn match_filters(&self, kind_mask: u8, platform_mask: u8) -> bool { + if self.http_profiles.is_empty() { + return false; + } + + #[cfg(feature = "tls")] + if self.tls_profiles.is_empty() { + return false; + } + + self.ua_kind as u8 & kind_mask != 0 && self.platform_kind as u8 & platform_mask != 0 + } 
+} + +impl UserAgentDatabase { + /// Create a new user agent database. + #[inline] + pub fn new() -> Self { + Self::default() + } +} + +#[derive(Debug, Clone, Default)] +pub struct UserAgentFilter { + pub kind: u8, + pub platform: u8, + pub initiator: Option, +} + +#[derive(Debug, Clone)] +pub struct UserAgentProfileQueryResult<'a> { + pub http: &'a crate::HttpProfile, + + #[cfg(feature = "tls")] + pub tls: &'a crate::TlsProfile, +} + +#[derive(Serialize, Deserialize)] +struct UserAgentFilterSerde { + kind: Option>, + platform: Option>, + initiator: Option, +} + +impl Serialize for UserAgentFilter { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + let mut kinds = Vec::new(); + if self.kind | UserAgentKind::Chromium as u8 != 0 { + kinds.push(UserAgentKind::Chromium); + } + if self.kind | UserAgentKind::Firefox as u8 != 0 { + kinds.push(UserAgentKind::Firefox); + } + if self.kind | UserAgentKind::Safari as u8 != 0 { + kinds.push(UserAgentKind::Safari); + } + + let mut platforms = Vec::new(); + if self.platform | PlatformKind::Windows as u8 != 0 { + platforms.push(PlatformKind::Windows); + } + if self.platform | PlatformKind::MacOS as u8 != 0 { + platforms.push(PlatformKind::MacOS); + } + if self.platform | PlatformKind::Linux as u8 != 0 { + platforms.push(PlatformKind::Linux); + } + if self.platform | PlatformKind::Android as u8 != 0 { + platforms.push(PlatformKind::Android); + } + if self.platform | PlatformKind::IOS as u8 != 0 { + platforms.push(PlatformKind::IOS); + } + + let filter = UserAgentFilterSerde { + kind: if kinds.is_empty() { None } else { Some(kinds) }, + platform: if platforms.is_empty() { + None + } else { + Some(platforms) + }, + initiator: self.initiator, + }; + filter.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for UserAgentFilter { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let filter = UserAgentFilterSerde::deserialize(deserializer)?; + let 
mut result = UserAgentFilter::default(); + if let Some(kinds) = filter.kind { + for kind in kinds { + result.kind |= kind as u8; + } + } + if let Some(platforms) = filter.platform { + for platform in platforms { + result.platform |= platform as u8; + } + } + if let Some(initiator) = filter.initiator { + result.initiator = Some(initiator); + } + Ok(result) + } +} + +impl UserAgentDatabase { + pub fn insert_http_profile(&mut self, profile: crate::UserAgentHttpProfile) { + let index = *self + .profile_keys + .entry(UserAgentProfileKey { + ua_kind: profile.ua_kind, + ua_kind_version: profile.ua_kind_version, + platform_kind: profile.platform_kind, + }) + .or_insert_with(|| { + let idx = self.profiles.len(); + self.profiles.push(UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + #[cfg(feature = "tls")] + tls_profiles: Vec::new(), + }); + idx + }); + let key = profile.http.key(); + self.http_profiles.insert(key, profile.http); + self.profiles[index].http_profiles.push(key); + } + + #[cfg(feature = "tls")] + pub fn insert_tls_profile(&mut self, profile: crate::UserAgentTlsProfile) { + let index = *self + .profile_keys + .entry(UserAgentProfileKey { + ua_kind: profile.ua_kind, + ua_kind_version: profile.ua_kind_version, + platform_kind: profile.platform_kind, + }) + .or_insert_with(|| { + let idx = self.profiles.len(); + self.profiles.push(UserAgentProfile { + ua_kind: profile.ua_kind, + platform_kind: profile.platform_kind, + http_profiles: Vec::new(), + #[cfg(feature = "tls")] + tls_profiles: Vec::new(), + }); + idx + }); + let key = profile.tls.key(); + self.tls_profiles.insert(key, profile.tls); + self.profiles[index].tls_profiles.push(key); + } + + pub fn query( + &self, + filters: Option, + ) -> Option> { + let filter = filters.unwrap_or_default(); + let mut rng = rand::rng(); + + let kind_mask = if filter.kind == 0 { + tracing::trace!("no kind filter provided, using all"); + u8::MAX + } else { + 
filter.kind + }; + + let platform_mask = if filter.platform == 0 { + tracing::trace!("no platform filter provided, using all"); + u8::MAX + } else { + filter.platform + }; + + let profiles: Vec<_> = self + .profiles + .iter() + .filter(|profile| profile.match_filters(kind_mask, platform_mask)) + .collect(); + if profiles.is_empty() { + tracing::debug!(?filter, "no profiles found for provided filters"); + return None; + } else { + tracing::trace!( + ?filter, + "found {} profile(s) for provided filters", + profiles.len() + ); + } + + // market share from https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + let weights: Vec = profiles + .iter() + .map(|profiles| match profiles.ua_kind { + UserAgentKind::Firefox => 0.03, + UserAgentKind::Safari => 0.18, + UserAgentKind::Chromium => 0.79, + }) + .collect(); + let dist = WeightedIndex::new(&weights).ok()?; + let profile = profiles.get(dist.sample(&mut rng))?; + + // try to get random http profile with initiator if defined, else random http profile + let http_profile_index = if let Some(initiator) = filter.initiator { + profile + .http_profiles + .iter() + .filter(|key| { + self.http_profiles + .get(key) + .map(|http| http.initiator == initiator) + .unwrap_or(false) + }) + .choose(&mut rng) + } else { + profile.http_profiles.choose(&mut rng) + }?; + let http_profile = self.http_profiles.get(http_profile_index)?; + + #[cfg(feature = "tls")] + let tls_profile = profile + .tls_profiles + .choose(&mut rng) + .and_then(|key| self.tls_profiles.get(key))?; + + Some(UserAgentProfileQueryResult { + http: http_profile, + #[cfg(feature = "tls")] + tls: tls_profile, + }) + } +} diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs new file mode 100644 index 00000000..df3b8ffc --- /dev/null +++ b/rama-ua/src/profile/ua.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +use crate::{PlatformKind, UserAgentKind}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct 
UserAgentProfile { + /// The kind of [`crate::UserAgent`] + pub ua_kind: UserAgentKind, + /// The version of the [`crate::UserAgent`] + pub ua_version: Option, + /// The platform the [`crate::UserAgent`] is running on. + pub platform: Option, + + /// The profile information regarding the http implementation of the [`crate::UserAgent`]. + pub http: super::HttpProfile, + + #[cfg(feature = "tls")] + /// The profile information regarding the tls implementation of the [`crate::UserAgent`]. + pub tls: super::TlsProfile, +} From 49e7323f691de453de25191363d3d8804d8611d1 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 10 Feb 2025 11:25:33 +0100 Subject: [PATCH 08/39] impl db more compatible with other UA layers --- rama-http-types/src/proto/h1/headers/map.rs | 6 + rama-ua/src/profile/db.rs | 118 +++++++- rama-ua/src/profile/http.rs | 6 +- rama-ua/src/profile/todo_delete_db.rs | 281 -------------------- rama-ua/src/profile/ua.rs | 11 + rama-ua/src/ua/info.rs | 83 +++--- 6 files changed, 185 insertions(+), 320 deletions(-) delete mode 100644 rama-ua/src/profile/todo_delete_db.rs diff --git a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs index c9f2469d..5945a7b4 100644 --- a/rama-http-types/src/proto/h1/headers/map.rs +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -3,6 +3,7 @@ use std::{ collections::{self, HashMap}, }; +use http::header::AsHeaderName; use serde::{de::Error as _, ser::Error as _, Deserialize, Serialize}; use super::{ @@ -55,6 +56,11 @@ impl Http1HeaderMap { } } + #[inline] + pub fn get(&self, key: impl AsHeaderName) -> Option<&HeaderValue> { + self.headers.get(key) + } + pub fn into_headers(self) -> HeaderMap { self.headers } diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index d48b9f75..9cdde4f4 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -1 +1,117 @@ -pub const TODO: usize = 42; +use rand::seq::IndexedRandom as _; +use 
std::collections::HashMap; + +use crate::{DeviceKind, PlatformKind, UserAgent, UserAgentKind, UserAgentProfile}; + +#[derive(Debug, Default)] +pub struct UserAgentDatabase { + profiles: Vec, + + map_ua_string: HashMap, + + map_platform: HashMap<(UserAgentKind, PlatformKind), Vec>, + map_device: HashMap<(UserAgentKind, DeviceKind), Vec>, +} + +impl UserAgentDatabase { + pub fn insert(&mut self, profile: UserAgentProfile) { + let index = self.profiles.len(); + if let Some(ua_header) = profile.ua_str() { + self.map_ua_string.insert(ua_header.to_string(), index); + } + + if let Some(platform) = profile.platform { + self.map_platform + .entry((profile.ua_kind, platform)) + .or_default() + .push(index); + self.map_device + .entry((profile.ua_kind, platform.device())) + .or_default() + .push(index); + } + + self.profiles.push(profile); + } + + pub fn get(&self, ua: &UserAgent) -> Option<&UserAgentProfile> { + if let Some(profile) = self + .map_ua_string + .get(ua.header_str()) + .and_then(|idx| self.profiles.get(*idx)) + { + return Some(profile); + } + + match ua.ua_kind() { + Some(ua_kind) => match ua.platform() { + // UA + Platform Match (e.g. chrome windows) + Some(platform) => self + .map_platform + .get(&(ua_kind, platform)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)), + // UA + Device match (e.g. firefox desktop) + None => { + let device = ua.device(); + self.map_device + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + }, + // Market-share Kind + Device match (e.g. 
chrome desktop) + None => { + let device = ua.device(); + + // market share from + // https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + let r = rand::random_range(0..=100); + let ua_kind = if r < 3 + && self + .map_device + .contains_key(&(UserAgentKind::Firefox, device)) + { + UserAgentKind::Firefox + } else if r < 18 + && self + .map_device + .contains_key(&(UserAgentKind::Safari, device)) + { + UserAgentKind::Safari + } else { + // ~79% x-x + UserAgentKind::Chromium + }; + + self.map_device + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + } + } + + #[inline] + pub fn iter(&self) -> impl Iterator { + self.profiles.iter() + } +} + +impl FromIterator for UserAgentDatabase { + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let (lb, _) = iter.size_hint(); + + let mut db = UserAgentDatabase { + profiles: Vec::with_capacity(lb), + ..Default::default() + }; + + for profile in iter { + db.insert(profile); + } + + db + } +} diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 67873f90..9adb82a4 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -11,9 +11,9 @@ pub struct HttpProfile { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct HttpHeadersProfile { pub navigate: Http1HeaderMap, - pub fetch: Http1HeaderMap, - pub xhr: Http1HeaderMap, - pub form: Http1HeaderMap, + pub fetch: Option, + pub xhr: Option, + pub form: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/rama-ua/src/profile/todo_delete_db.rs b/rama-ua/src/profile/todo_delete_db.rs deleted file mode 100644 index 0c687c58..00000000 --- a/rama-ua/src/profile/todo_delete_db.rs +++ /dev/null @@ -1,281 +0,0 @@ -use rand::{ - distr::{weighted::WeightedIndex, Distribution as _}, - seq::{IndexedRandom as _, IteratorRandom as _}, -}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -use 
crate::{Initiator, PlatformKind, UserAgentKind}; - -#[derive(Debug, Default)] -pub struct UserAgentDatabase { - profile_keys: HashMap, - profiles: Vec, - - http_profiles: HashMap, - - #[cfg(feature = "tls")] - tls_profiles: HashMap, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct UserAgentProfileKey { - pub ua_kind: UserAgentKind, - pub ua_kind_version: usize, - pub platform_kind: PlatformKind, -} - -#[derive(Debug)] -struct UserAgentProfile { - pub ua_kind: UserAgentKind, - pub platform_kind: PlatformKind, - pub http_profiles: Vec, - - #[cfg(feature = "tls")] - pub tls_profiles: Vec, -} - -impl UserAgentProfile { - fn match_filters(&self, kind_mask: u8, platform_mask: u8) -> bool { - if self.http_profiles.is_empty() { - return false; - } - - #[cfg(feature = "tls")] - if self.tls_profiles.is_empty() { - return false; - } - - self.ua_kind as u8 & kind_mask != 0 && self.platform_kind as u8 & platform_mask != 0 - } -} - -impl UserAgentDatabase { - /// Create a new user agent database. 
- #[inline] - pub fn new() -> Self { - Self::default() - } -} - -#[derive(Debug, Clone, Default)] -pub struct UserAgentFilter { - pub kind: u8, - pub platform: u8, - pub initiator: Option, -} - -#[derive(Debug, Clone)] -pub struct UserAgentProfileQueryResult<'a> { - pub http: &'a crate::HttpProfile, - - #[cfg(feature = "tls")] - pub tls: &'a crate::TlsProfile, -} - -#[derive(Serialize, Deserialize)] -struct UserAgentFilterSerde { - kind: Option>, - platform: Option>, - initiator: Option, -} - -impl Serialize for UserAgentFilter { - fn serialize(&self, serializer: S) -> Result - where - S: serde::ser::Serializer, - { - let mut kinds = Vec::new(); - if self.kind | UserAgentKind::Chromium as u8 != 0 { - kinds.push(UserAgentKind::Chromium); - } - if self.kind | UserAgentKind::Firefox as u8 != 0 { - kinds.push(UserAgentKind::Firefox); - } - if self.kind | UserAgentKind::Safari as u8 != 0 { - kinds.push(UserAgentKind::Safari); - } - - let mut platforms = Vec::new(); - if self.platform | PlatformKind::Windows as u8 != 0 { - platforms.push(PlatformKind::Windows); - } - if self.platform | PlatformKind::MacOS as u8 != 0 { - platforms.push(PlatformKind::MacOS); - } - if self.platform | PlatformKind::Linux as u8 != 0 { - platforms.push(PlatformKind::Linux); - } - if self.platform | PlatformKind::Android as u8 != 0 { - platforms.push(PlatformKind::Android); - } - if self.platform | PlatformKind::IOS as u8 != 0 { - platforms.push(PlatformKind::IOS); - } - - let filter = UserAgentFilterSerde { - kind: if kinds.is_empty() { None } else { Some(kinds) }, - platform: if platforms.is_empty() { - None - } else { - Some(platforms) - }, - initiator: self.initiator, - }; - filter.serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for UserAgentFilter { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - let filter = UserAgentFilterSerde::deserialize(deserializer)?; - let mut result = UserAgentFilter::default(); - if let Some(kinds) = 
filter.kind { - for kind in kinds { - result.kind |= kind as u8; - } - } - if let Some(platforms) = filter.platform { - for platform in platforms { - result.platform |= platform as u8; - } - } - if let Some(initiator) = filter.initiator { - result.initiator = Some(initiator); - } - Ok(result) - } -} - -impl UserAgentDatabase { - pub fn insert_http_profile(&mut self, profile: crate::UserAgentHttpProfile) { - let index = *self - .profile_keys - .entry(UserAgentProfileKey { - ua_kind: profile.ua_kind, - ua_kind_version: profile.ua_kind_version, - platform_kind: profile.platform_kind, - }) - .or_insert_with(|| { - let idx = self.profiles.len(); - self.profiles.push(UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - #[cfg(feature = "tls")] - tls_profiles: Vec::new(), - }); - idx - }); - let key = profile.http.key(); - self.http_profiles.insert(key, profile.http); - self.profiles[index].http_profiles.push(key); - } - - #[cfg(feature = "tls")] - pub fn insert_tls_profile(&mut self, profile: crate::UserAgentTlsProfile) { - let index = *self - .profile_keys - .entry(UserAgentProfileKey { - ua_kind: profile.ua_kind, - ua_kind_version: profile.ua_kind_version, - platform_kind: profile.platform_kind, - }) - .or_insert_with(|| { - let idx = self.profiles.len(); - self.profiles.push(UserAgentProfile { - ua_kind: profile.ua_kind, - platform_kind: profile.platform_kind, - http_profiles: Vec::new(), - #[cfg(feature = "tls")] - tls_profiles: Vec::new(), - }); - idx - }); - let key = profile.tls.key(); - self.tls_profiles.insert(key, profile.tls); - self.profiles[index].tls_profiles.push(key); - } - - pub fn query( - &self, - filters: Option, - ) -> Option> { - let filter = filters.unwrap_or_default(); - let mut rng = rand::rng(); - - let kind_mask = if filter.kind == 0 { - tracing::trace!("no kind filter provided, using all"); - u8::MAX - } else { - filter.kind - }; - - let platform_mask = if filter.platform == 0 { - 
tracing::trace!("no platform filter provided, using all"); - u8::MAX - } else { - filter.platform - }; - - let profiles: Vec<_> = self - .profiles - .iter() - .filter(|profile| profile.match_filters(kind_mask, platform_mask)) - .collect(); - if profiles.is_empty() { - tracing::debug!(?filter, "no profiles found for provided filters"); - return None; - } else { - tracing::trace!( - ?filter, - "found {} profile(s) for provided filters", - profiles.len() - ); - } - - // market share from https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) - let weights: Vec = profiles - .iter() - .map(|profiles| match profiles.ua_kind { - UserAgentKind::Firefox => 0.03, - UserAgentKind::Safari => 0.18, - UserAgentKind::Chromium => 0.79, - }) - .collect(); - let dist = WeightedIndex::new(&weights).ok()?; - let profile = profiles.get(dist.sample(&mut rng))?; - - // try to get random http profile with initiator if defined, else random http profile - let http_profile_index = if let Some(initiator) = filter.initiator { - profile - .http_profiles - .iter() - .filter(|key| { - self.http_profiles - .get(key) - .map(|http| http.initiator == initiator) - .unwrap_or(false) - }) - .choose(&mut rng) - } else { - profile.http_profiles.choose(&mut rng) - }?; - let http_profile = self.http_profiles.get(http_profile_index)?; - - #[cfg(feature = "tls")] - let tls_profile = profile - .tls_profiles - .choose(&mut rng) - .and_then(|key| self.tls_profiles.get(key))?; - - Some(UserAgentProfileQueryResult { - http: http_profile, - #[cfg(feature = "tls")] - tls: tls_profile, - }) - } -} diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs index df3b8ffc..c53859ac 100644 --- a/rama-ua/src/profile/ua.rs +++ b/rama-ua/src/profile/ua.rs @@ -1,3 +1,4 @@ +use rama_http_types::header::USER_AGENT; use serde::{Deserialize, Serialize}; use crate::{PlatformKind, UserAgentKind}; @@ -18,3 +19,13 @@ pub struct UserAgentProfile { /// The profile information regarding the tls 
implementation of the [`crate::UserAgent`]. pub tls: super::TlsProfile, } + +impl UserAgentProfile { + pub fn ua_str(&self) -> Option<&str> { + self.http + .headers + .navigate + .get(USER_AGENT) + .and_then(|v| v.to_str().ok()) + } +} diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index 27de84a6..fa5e227b 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -84,18 +84,10 @@ impl UserAgent { /// returns the device kind of the [`UserAgent`]. pub fn device(&self) -> DeviceKind { match &self.data { - UserAgentData::Standard { platform, .. } => match platform { - Some(PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux) | None => { - DeviceKind::Desktop - } - Some(PlatformKind::Android | PlatformKind::IOS) => DeviceKind::Mobile, - }, - UserAgentData::Platform(platform) => match platform { - PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux => { - DeviceKind::Desktop - } - PlatformKind::Android | PlatformKind::IOS => DeviceKind::Mobile, - }, + UserAgentData::Standard { platform, .. } => { + platform.map(|p| p.device()).unwrap_or(DeviceKind::Desktop) + } + UserAgentData::Platform(platform) => platform.device(), UserAgentData::Device(kind) => *kind, UserAgentData::Unknown => DeviceKind::Desktop, } @@ -111,6 +103,23 @@ impl UserAgent { } } + /// returns the [`UserAgentKind`] used by the [`UserAgent`], if known. + pub fn ua_kind(&self) -> Option { + match self.http_agent_overwrite { + Some(HttpAgent::Chromium) => Some(UserAgentKind::Chromium), + Some(HttpAgent::Safari) => Some(UserAgentKind::Safari), + Some(HttpAgent::Firefox) => Some(UserAgentKind::Firefox), + Some(HttpAgent::Preserve) => None, + None => match &self.data { + UserAgentData::Standard { + info: UserAgentInfo { kind, .. }, + .. + } => Some(*kind), + _ => None, + }, + } + } + /// returns the [`PlatformKind`] used by the [`UserAgent`], if known. /// /// This is the platform the [`UserAgent`] is running on. 
@@ -171,14 +180,13 @@ impl FromStr for UserAgent { /// The kind of [`UserAgent`] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] pub enum UserAgentKind { /// Chromium Browser - Chromium = 0b0000_0001, + Chromium, /// Firefox Browser - Firefox = 0b0000_0010, + Firefox, /// Safari Browser - Safari = 0b0000_0100, + Safari, } impl UserAgentKind { @@ -233,12 +241,11 @@ impl<'de> Deserialize<'de> for UserAgentKind { /// Device on which the [`UserAgent`] operates. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] pub enum DeviceKind { /// Personal Computers - Desktop = 0b0000_0001, + Desktop, /// Phones, Tablets and other mobile devices - Mobile = 0b0000_0010, + Mobile, } impl DeviceKind { @@ -258,18 +265,17 @@ impl fmt::Display for DeviceKind { /// Platform within the [`UserAgent`] operates. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(u8)] pub enum PlatformKind { /// Windows Platform ([`Desktop`](DeviceKind::Desktop)) - Windows = 0b0000_0001, + Windows, /// MacOS Platform ([`Desktop`](DeviceKind::Desktop)) - MacOS = 0b0000_0010, + MacOS, /// Linux Platform ([`Desktop`](DeviceKind::Desktop)) - Linux = 0b0000_0100, + Linux, /// Android Platform ([`Mobile`](DeviceKind::Mobile)) - Android = 0b0001_0000, + Android, /// iOS Platform ([`Mobile`](DeviceKind::Mobile)) - IOS = 0b0010_0000, + IOS, } impl PlatformKind { @@ -282,6 +288,15 @@ impl PlatformKind { PlatformKind::IOS => "iOS", } } + + pub fn device(&self) -> DeviceKind { + match self { + PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux => { + DeviceKind::Desktop + } + PlatformKind::Android | PlatformKind::IOS => DeviceKind::Mobile, + } + } } impl FromStr for PlatformKind { @@ -328,19 +343,18 @@ impl fmt::Display for PlatformKind { /// Http implementation used by the [`UserAgent`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[repr(u8)] pub enum HttpAgent { /// Chromium based browsers share the same http implementation - Chromium = 0b0000_0001, + Chromium, /// 
Firefox has its own http implementation - Firefox = 0b0000_0010, + Firefox, /// Safari also has its own http implementation - Safari = 0b0000_0100, + Safari, /// Preserve the incoming Http Agent as much as possible. /// /// For emulators this means that emulators will aim to have a /// hands-off approach to the incoming http request. - Preserve = 0b1000_0000, + Preserve, } impl HttpAgent { @@ -397,21 +411,20 @@ impl FromStr for HttpAgent { /// Tls implementation used by the [`UserAgent`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[repr(u8)] pub enum TlsAgent { /// Rustls is used as a fallback for all user agents, /// that are not chromium based. - Rustls = 0b0000_0001, + Rustls, /// Boringssl is used for Chromium based user agents. - Boringssl = 0b0000_0010, + Boringssl, /// NSS is used for Firefox - Nss = 0b0000_0100, + Nss, /// Preserve the incoming TlsAgent as much as possible. /// /// For this Tls this means that emulators can try to /// preserve details of the incoming Tls connection /// such as the (Tls) Client Hello. 
- Preserve = 0b1000_0000, + Preserve, } impl TlsAgent { From 4ea52f00bbc7b6e086c6f45d0fc5ebca2c177c2e Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sat, 15 Feb 2025 10:25:31 +0100 Subject: [PATCH 09/39] fix lint error --- rama-ua/src/profile/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 9cdde4f4..b354116c 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -17,7 +17,7 @@ impl UserAgentDatabase { pub fn insert(&mut self, profile: UserAgentProfile) { let index = self.profiles.len(); if let Some(ua_header) = profile.ua_str() { - self.map_ua_string.insert(ua_header.to_string(), index); + self.map_ua_string.insert(ua_header.to_owned(), index); } if let Some(platform) = profile.platform { From 8a4a8ba04ef3b85a6aee08aace430132eedeb530 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sat, 15 Feb 2025 23:11:16 +0100 Subject: [PATCH 10/39] start working towards UA emulation: add UserAgentProvider logic this provider will be used by the layer logic to select a UA based on the given ctx which also allows someone else to implent their own logic --- rama-ua/src/emulate/layer.rs | 1 + rama-ua/src/emulate/mod.rs | 5 ++ rama-ua/src/emulate/provider.rs | 84 +++++++++++++++++++++++++++++++++ rama-ua/src/lib.rs | 3 ++ 4 files changed, 93 insertions(+) create mode 100644 rama-ua/src/emulate/layer.rs create mode 100644 rama-ua/src/emulate/mod.rs create mode 100644 rama-ua/src/emulate/provider.rs diff --git a/rama-ua/src/emulate/layer.rs b/rama-ua/src/emulate/layer.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/rama-ua/src/emulate/layer.rs @@ -0,0 +1 @@ + diff --git a/rama-ua/src/emulate/mod.rs b/rama-ua/src/emulate/mod.rs new file mode 100644 index 00000000..bfe88300 --- /dev/null +++ b/rama-ua/src/emulate/mod.rs @@ -0,0 +1,5 @@ +mod provider; +pub use provider::UserAgentProvider; + +mod layer; +// TODO diff --git 
a/rama-ua/src/emulate/provider.rs b/rama-ua/src/emulate/provider.rs new file mode 100644 index 00000000..91edb42d --- /dev/null +++ b/rama-ua/src/emulate/provider.rs @@ -0,0 +1,84 @@ +use std::sync::Arc; + +use rama_core::Context; + +use crate::{UserAgentDatabase, UserAgentProfile}; + +pub trait UserAgentProvider: Send + Sync + 'static { + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile>; +} + +impl UserAgentProvider for () { + #[inline] + fn select_user_agent_profile(&self, _ctx: &Context) -> Option<&UserAgentProfile> { + None + } +} + +impl UserAgentProvider for UserAgentProfile { + #[inline] + fn select_user_agent_profile(&self, _ctx: &Context) -> Option<&UserAgentProfile> { + Some(self) + } +} + +impl UserAgentProvider for UserAgentDatabase { + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + ctx.get().and_then(|agent| self.get(agent)) + } +} + +impl UserAgentProvider for Option

+where + P: UserAgentProvider, +{ + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + self.as_ref().and_then(|p| p.select_user_agent_profile(ctx)) + } +} + +impl UserAgentProvider for Arc

+where + P: UserAgentProvider, +{ + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + self.as_ref().select_user_agent_profile(ctx) + } +} + +impl UserAgentProvider for Box

+where + P: UserAgentProvider, +{ + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + self.as_ref().select_user_agent_profile(ctx) + } +} + +macro_rules! impl_user_agent_provider_either { + ($id:ident, $($param:ident),+ $(,)?) => { + impl UserAgentProvider for ::rama_core::combinators::$id<$($param),+> + where + $( + $param: UserAgentProvider, + )+ + { + fn select_user_agent_profile( + &self, + ctx: &Context, + ) -> Option<&UserAgentProfile> { + match self { + $( + ::rama_core::combinators::$id::$param(s) => s.select_user_agent_profile(ctx), + )+ + } + } + } + }; +} + +::rama_core::combinators::impl_either!(impl_user_agent_provider_either); diff --git a/rama-ua/src/lib.rs b/rama-ua/src/lib.rs index 7c155539..4c196c6f 100644 --- a/rama-ua/src/lib.rs +++ b/rama-ua/src/lib.rs @@ -65,3 +65,6 @@ pub use ua::*; mod profile; pub use profile::*; + +mod emulate; +pub use emulate::*; From 6f0478bba5f4941a0c22f7dab580d5b394678270 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 17 Feb 2025 15:04:22 +0100 Subject: [PATCH 11/39] start to write the UA emulate layer --- rama-http-backend/src/client/conn.rs | 8 +- rama-http-backend/src/client/mod.rs | 46 +++- rama-http-types/src/conn.rs | 15 ++ rama-http-types/src/lib.rs | 2 + rama-http-types/src/proto/h2/pseudo_header.rs | 20 ++ rama-http/src/layer/ua.rs | 6 + rama-net/src/tls/client/config.rs | 23 +- rama-net/src/tls/client/mod.rs | 2 +- rama-ua/src/emulate/layer.rs | 56 ++++ rama-ua/src/emulate/mod.rs | 5 +- rama-ua/src/emulate/service.rs | 247 ++++++++++++++++++ rama-ua/src/profile/tls.rs | 47 +++- rama-ua/src/ua/info.rs | 14 +- rama-ua/src/ua/mod.rs | 72 ++++- rama-ua/src/ua/parse.rs | 6 + .../http_user_agent_classifier.rs | 1 + 16 files changed, 549 insertions(+), 21 deletions(-) create mode 100644 rama-http-types/src/conn.rs create mode 100644 rama-ua/src/emulate/service.rs diff --git a/rama-http-backend/src/client/conn.rs b/rama-http-backend/src/client/conn.rs 
index c94a36bd..0676ea71 100644 --- a/rama-http-backend/src/client/conn.rs +++ b/rama-http-backend/src/client/conn.rs @@ -3,7 +3,7 @@ use rama_core::{ error::{BoxError, OpaqueError}, Context, Layer, Service, }; -use rama_http_types::{dep::http_body, Request, Version}; +use rama_http_types::{conn::Http1ClientContextParams, dep::http_body, Request, Version}; use rama_net::{ client::{ConnectorService, EstablishedClientConnection}, stream::Stream, @@ -125,7 +125,11 @@ where } Version::HTTP_11 | Version::HTTP_10 | Version::HTTP_09 => { trace!(uri = %req.uri(), "create ~h1 client executor"); - let (sender, conn) = rama_http_core::client::conn::http1::handshake(io).await?; + let mut builder = rama_http_core::client::conn::http1::Builder::new(); + if let Some(params) = ctx.get::() { + builder.title_case_headers(params.title_header_case); + } + let (sender, conn) = builder.handshake(io).await?; ctx.spawn(async move { if let Err(err) = conn.await { diff --git a/rama-http-backend/src/client/mod.rs b/rama-http-backend/src/client/mod.rs index bd6aebfa..80fb1a4c 100644 --- a/rama-http-backend/src/client/mod.rs +++ b/rama-http-backend/src/client/mod.rs @@ -1,13 +1,18 @@ //! Rama HTTP client module, //! which provides the [`HttpClient`] type to serve HTTP requests. 
+use std::sync::Arc; + use proxy::layer::HttpProxyConnector; use rama_core::{ error::{BoxError, ErrorExt, OpaqueError}, Context, Service, }; use rama_http_types::{dep::http_body, Request, Response}; -use rama_net::client::{ConnectorService, EstablishedClientConnection}; +use rama_net::{ + client::{ConnectorService, EstablishedClientConnection}, + tls::client::ProxyClientConfig, +}; use rama_tcp::client::service::TcpConnector; #[cfg(any(feature = "rustls", feature = "boring"))] @@ -118,15 +123,29 @@ where #[cfg(any(feature = "rustls", feature = "boring"))] let connector = { - let proxy_tls_connector_data = match &self.proxy_tls_config { - Some(proxy_tls_config) => { + let proxy_tls_connector_data = match ( + ctx.get::(), + &self.proxy_tls_config, + ) { + (Some(proxy_tls_config), _) => { + trace!("create proxy tls connector using rama tls client config from ontext"); + proxy_tls_config + .0 + .as_ref() + .clone() + .try_into() + .context( + "HttpClient: create proxy tls connector data from tls config found in context", + )? + } + (None, Some(proxy_tls_config)) => { trace!("create proxy tls connector using pre-defined rama tls client config"); proxy_tls_config .clone() .try_into() .context("HttpClient: create proxy tls connector data from tls config")? } - None => { + (None, None) => { trace!("create proxy tls connector using the 'new_http_auto' constructor"); TlsConnectorData::new().context( "HttpClient: create proxy tls connector data with no application presets", @@ -138,15 +157,20 @@ where TlsConnector::tunnel(tcp_connector, None) .with_connector_data(proxy_tls_connector_data), ); - let tls_connector_data = match &self.tls_config { - Some(tls_config) => { + let tls_connector_data = match (ctx.get::>(), &self.tls_config) { + (Some(tls_config), _) => { + trace!("create tls connector using rama tls client config from ontext"); + tls_config.as_ref().clone().try_into().context( + "HttpClient: create tls connector data from tls config found in context", + )? 
+ } + (None, Some(tls_config)) => { trace!("create tls connector using pre-defined rama tls client config"); - tls_config - .clone() - .try_into() - .context("HttpClient: create tls connector data from tls config")? + tls_config.clone().try_into().context( + "HttpClient: create tls connector data from pre-defined tls config", + )? } - None => { + (None, None) => { trace!("create tls connector using the 'new_http_auto' constructor"); TlsConnectorData::new_http_auto() .context("HttpClient: create tls connector data for http (auto)")? diff --git a/rama-http-types/src/conn.rs b/rama-http-types/src/conn.rs new file mode 100644 index 00000000..0f662cc5 --- /dev/null +++ b/rama-http-types/src/conn.rs @@ -0,0 +1,15 @@ +//! HTTP connection utilities. + +#[derive(Debug, Clone, Default)] +/// Optional parameters that can be set in the [`Context`] of a (h1) request +/// to customise the connection of the h1 connection. +/// +/// Can be used by Http connector services, especially in the context of proxies, +/// where there might not be one static config that is to be applied to all client connections. +pub struct Http1ClientContextParams { + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Default is false. + pub title_header_case: bool, +} diff --git a/rama-http-types/src/lib.rs b/rama-http-types/src/lib.rs index cf60b2e7..b20e95ba 100644 --- a/rama-http-types/src/lib.rs +++ b/rama-http-types/src/lib.rs @@ -37,6 +37,8 @@ pub mod proto; pub mod headers; +pub mod conn; + pub mod dep { //! Dependencies for rama http modules. //! 
diff --git a/rama-http-types/src/proto/h2/pseudo_header.rs b/rama-http-types/src/proto/h2/pseudo_header.rs index a26a6689..5ec6487e 100644 --- a/rama-http-types/src/proto/h2/pseudo_header.rs +++ b/rama-http-types/src/proto/h2/pseudo_header.rs @@ -137,6 +137,26 @@ impl IntoIterator for PseudoHeaderOrder { } } +impl FromIterator for PseudoHeaderOrder { + fn from_iter>(iter: T) -> Self { + let mut this = Self::new(); + for header in iter { + this.push(header); + } + this + } +} + +impl<'a> FromIterator<&'a PseudoHeader> for PseudoHeaderOrder { + fn from_iter>(iter: T) -> Self { + let mut this = Self::new(); + for header in iter { + this.push(*header); + } + this + } +} + #[derive(Debug)] /// Iterator over a copy of [`PseudoHeaderOrder`]. pub struct PseudoHeaderOrderIter { diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index a34e6618..b1f0893d 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -148,6 +148,9 @@ where if let Some(preserve_ua) = overwrites.preserve_ua { ua.with_preserve_ua_header(preserve_ua); } + if let Some(req_init) = overwrites.req_init { + ua.with_request_initiator(req_init); + } } } @@ -206,6 +209,7 @@ mod tests { use crate::{headers, IntoResponse, Response, StatusCode}; use rama_core::service::service_fn; use rama_core::Context; + use rama_ua::RequestInitiator; use std::convert::Infallible; #[tokio::test] @@ -310,6 +314,7 @@ mod tests { assert_eq!(ua.http_agent(), HttpAgent::Safari); assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); assert!(ua.preserve_ua_header()); + assert_eq!(ua.request_initiator(), Some(RequestInitiator::Xhr)); Ok(StatusCode::OK.into_response()) } @@ -327,6 +332,7 @@ mod tests { http: Some(HttpAgent::Safari), tls: Some(TlsAgent::Boringssl), preserve_ua: Some(true), + req_init: Some(RequestInitiator::Xhr), }) .unwrap(), ) diff --git a/rama-net/src/tls/client/config.rs b/rama-net/src/tls/client/config.rs index a938a49a..b9e46d83 100644 --- a/rama-net/src/tls/client/config.rs +++ 
b/rama-net/src/tls/client/config.rs @@ -1,5 +1,15 @@ +use std::sync::Arc; + use super::{merge_client_hello_lists, ClientHelloExtension}; -use crate::tls::{CipherSuite, CompressionAlgorithm, DataEncoding, KeyLogIntent}; +use crate::tls::{CipherSuite, CompressionAlgorithm, DataEncoding, KeyLogIntent, ProtocolVersion}; + +#[derive(Debug, Clone, Default)] +/// Common API to configure a Proxy TLS Client +/// +/// See [`ClientConfig`] for more information, +/// this is only a new-type wrapper to be able to differentiate +/// the info found in context for a dynamic https client. +pub struct ProxyClientConfig(pub Arc); #[derive(Debug, Clone, Default)] /// Common API to configure a TLS Client @@ -96,3 +106,14 @@ impl From for ClientConfig { } } } + +impl From for super::ClientHello { + fn from(value: ClientConfig) -> Self { + super::ClientHello { + protocol_version: ProtocolVersion::TLSv1_2, + cipher_suites: value.cipher_suites.unwrap_or_default(), + compression_algorithms: value.compression_algorithms.unwrap_or_default(), + extensions: value.extensions.unwrap_or_default(), + } + } +} diff --git a/rama-net/src/tls/client/mod.rs b/rama-net/src/tls/client/mod.rs index 17207d61..d0c2d39e 100644 --- a/rama-net/src/tls/client/mod.rs +++ b/rama-net/src/tls/client/mod.rs @@ -20,7 +20,7 @@ pub(crate) use parser::parse_client_hello; mod config; #[doc(inline)] -pub use config::{ClientAuth, ClientAuthData, ClientConfig, ServerVerifyMode}; +pub use config::{ClientAuth, ClientAuthData, ClientConfig, ProxyClientConfig, ServerVerifyMode}; use super::{ApplicationProtocol, DataEncoding, ProtocolVersion}; diff --git a/rama-ua/src/emulate/layer.rs b/rama-ua/src/emulate/layer.rs index 8b137891..35f89f3b 100644 --- a/rama-ua/src/emulate/layer.rs +++ b/rama-ua/src/emulate/layer.rs @@ -1 +1,57 @@ +use std::fmt; +use rama_core::Layer; + +pub struct UserAgentEmulateLayer

{ + provider: P, + optional: bool, +} + +impl fmt::Debug for UserAgentEmulateLayer

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UserAgentEmulateLayer") + .field("provider", &self.provider) + .field("optional", &self.optional) + .finish() + } +} + +impl Clone for UserAgentEmulateLayer

{ + fn clone(&self) -> Self { + Self { + provider: self.provider.clone(), + optional: self.optional, + } + } +} + +impl

UserAgentEmulateLayer

{ + pub fn new(provider: P) -> Self { + Self { + provider, + optional: false, + } + } + + /// When no user agent profile was found it will + /// fail the request unless optional is true. In case of + /// the latter the service will do nothing. + pub fn optional(mut self, optional: bool) -> Self { + self.optional = optional; + self + } + + /// See [`Self::optional`]. + pub fn set_optional(&mut self, optional: bool) -> &mut Self { + self.optional = optional; + self + } +} + +impl Layer for UserAgentEmulateLayer

{ + type Service = super::UserAgentEmulateService; + + fn layer(&self, inner: S) -> Self::Service { + super::UserAgentEmulateService::new(inner, self.provider.clone()).optional(self.optional) + } +} diff --git a/rama-ua/src/emulate/mod.rs b/rama-ua/src/emulate/mod.rs index bfe88300..1ffc74c0 100644 --- a/rama-ua/src/emulate/mod.rs +++ b/rama-ua/src/emulate/mod.rs @@ -2,4 +2,7 @@ mod provider; pub use provider::UserAgentProvider; mod layer; -// TODO +pub use layer::UserAgentEmulateLayer; + +mod service; +pub use service::UserAgentEmulateService; diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs new file mode 100644 index 00000000..65327de5 --- /dev/null +++ b/rama-ua/src/emulate/service.rs @@ -0,0 +1,247 @@ +use std::fmt; + +use rama_core::{ + error::{BoxError, OpaqueError}, + Context, Service, +}; +use rama_http_types::{ + conn::Http1ClientContextParams, + header::CONTENT_TYPE, + proto::{h1::Http1HeaderMap, h2::PseudoHeaderOrder}, + HeaderName, Method, Request, Version, +}; +use rama_utils::macros::match_ignore_ascii_case_str; + +use crate::{RequestInitiator, UserAgent, UserAgentProfile}; + +use super::UserAgentProvider; + +pub struct UserAgentEmulateService { + inner: S, + provider: P, + optional: bool, +} + +impl fmt::Debug for UserAgentEmulateService { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UserAgentEmulateService") + .field("inner", &self.inner) + .field("provider", &self.provider) + .field("optional", &self.optional) + .finish() + } +} + +impl Clone for UserAgentEmulateService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + provider: self.provider.clone(), + optional: self.optional, + } + } +} + +impl UserAgentEmulateService { + pub fn new(inner: S, provider: P) -> Self { + Self { + inner, + provider, + optional: false, + } + } + + /// When no user agent profile was found it will + /// fail the request unless optional is true. 
In case of + /// the latter the service will do nothing. + pub fn optional(mut self, optional: bool) -> Self { + self.optional = optional; + self + } + + /// See [`Self::optional`]. + pub fn set_optional(&mut self, optional: bool) -> &mut Self { + self.optional = optional; + self + } +} + +impl Service> for UserAgentEmulateService +where + State: Clone + Send + Sync + 'static, + Body: Send + Sync + 'static, + S: Service, Error: Into>, + P: UserAgentProvider, +{ + type Response = S::Response; + type Error = BoxError; + + async fn serve( + &self, + mut ctx: Context, + req: Request, + ) -> Result { + let profile = match self.provider.select_user_agent_profile(&ctx) { + Some(profile) => profile, + None => { + return if self.optional { + self.inner.serve(ctx, req).await.map_err(Into::into) + } else { + Err(OpaqueError::from_display( + "requirement not fulfilled: user agent profile could not be selected", + ) + .into_boxed()) + }; + } + }; + + tracing::debug!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent profile selected for emulation" + ); + + emulate_http_settings(&mut ctx, &req, profile); + let _base_http_headers = get_base_http_headers(&ctx, &req, profile); + + // TODO: merge base headers with incoming headers... allowing some to overwrite, others not... + // also allowing anyway to overwrite if something is set... 
+ + #[cfg(feature = "tls")] + { + // this Arc is to be lazilly cloned by a tls connector + // only when a connection is to be made, as to play nicely + // with concepts such as connection pooling + ctx.insert(profile.tls.client_config.clone()); + } + + #[allow(clippy::todo)] + { + todo!() + } + } +} + +fn emulate_http_settings( + ctx: &mut Context, + req: &Request, + profile: &UserAgentProfile, +) { + match req.version() { + Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add http1-specific settings", + ); + ctx.insert(Http1ClientContextParams { + title_header_case: profile.http.h1.title_case_headers, + }); + } + Version::HTTP_2 => { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add h2-specific settings", + ); + ctx.insert(PseudoHeaderOrder::from_iter( + profile.http.h2.http_pseudo_headers.iter(), + )); + } + Version::HTTP_3 => tracing::debug!( + "UA emulation not yet supported for h3: not applying anything h3-specific" + ), + _ => tracing::debug!( + version = ?req.version(), + "UA emulation not supported for unknown http version: not applying anything version-specific", + ), + } +} + +fn get_base_http_headers<'a, Body, State>( + ctx: &Context, + req: &Request, + profile: &'a UserAgentProfile, +) -> &'a Http1HeaderMap { + match ctx.get::().and_then(|ua| ua.request_initiator()) { + Some(req_init) => { + tracing::trace!(%req_init, "base http headers defined based on hint from UserAgent (overwrite)"); + get_base_http_headers_from_req_init(req_init, profile) + } + // NOTE: the primitive checks below are pretty bad, + // feel free to help improve. Just need to make sure it has good enough fallbacks, + // and that they are cheap enough to check. 
+ None => match *req.method() { + Method::GET => { + tracing::trace!("base http headers defined based on Get=Navigate assumption"); + &profile.http.headers.navigate + } + Method::POST => { + let req_init = req + .headers() + .get(CONTENT_TYPE) + .and_then(|ct| ct.to_str().ok()) + .and_then(|s| { + match_ignore_ascii_case_str! { + match (s) { + "form-" => Some(RequestInitiator::Form), + _ => None, + } + } + }) + .unwrap_or(RequestInitiator::Fetch); + tracing::trace!(%req_init, "base http headers defined based on Post=FormOrFetch assumption"); + get_base_http_headers_from_req_init(req_init, profile) + } + _ => { + let req_init = req + .headers() + .get(HeaderName::from_static("x-requested-with")) + .and_then(|ct| ct.to_str().ok()) + .and_then(|s| { + match_ignore_ascii_case_str! { + match (s) { + "XmlHttpRequest" => Some(RequestInitiator::Xhr), + _ => None, + } + } + }) + .unwrap_or(RequestInitiator::Fetch); + tracing::trace!(%req_init, "base http headers defined based on XhrOrFetch assumption"); + get_base_http_headers_from_req_init(req_init, profile) + } + }, + } +} + +fn get_base_http_headers_from_req_init( + req_init: RequestInitiator, + profile: &UserAgentProfile, +) -> &Http1HeaderMap { + match req_init { + RequestInitiator::Navigate => &profile.http.headers.navigate, + RequestInitiator::Form => profile + .http + .headers + .form + .as_ref() + .unwrap_or(&profile.http.headers.navigate), + RequestInitiator::Xhr => profile + .http + .headers + .xhr + .as_ref() + .or(profile.http.headers.fetch.as_ref()) + .unwrap_or(&profile.http.headers.navigate), + RequestInitiator::Fetch => profile + .http + .headers + .fetch + .as_ref() + .or(profile.http.headers.xhr.as_ref()) + .unwrap_or(&profile.http.headers.navigate), + } +} diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs index 1d88ba69..240bddde 100644 --- a/rama-ua/src/profile/tls.rs +++ b/rama-ua/src/profile/tls.rs @@ -1,7 +1,48 @@ -use rama_net::tls::client::ClientHello; +use std::sync::Arc; + +use 
rama_net::tls::client::{ClientConfig, ClientHello, ServerVerifyMode}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Deserialize, Serialize, Hash)] +#[derive(Debug, Clone)] pub struct TlsProfile { - pub client_hello: ClientHello, + pub client_config: Arc, +} + +impl<'de> Deserialize<'de> for TlsProfile { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let input = TlsProfileSerde::deserialize(deserializer)?; + let mut cfg = ClientConfig::from(input.client_hello); + if input.insecure { + cfg.server_verify_mode = Some(ServerVerifyMode::Disable); + } + Ok(Self { + client_config: Arc::new(cfg), + }) + } +} + +impl Serialize for TlsProfile { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let insecure = matches!( + self.client_config.server_verify_mode, + Some(ServerVerifyMode::Disable) + ); + TlsProfileSerde { + client_hello: self.client_config.as_ref().clone().into(), + insecure, + } + .serialize(serializer) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct TlsProfileSerde { + client_hello: ClientHello, + insecure: bool, } diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index fa5e227b..2d169b64 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -1,4 +1,4 @@ -use super::parse_http_user_agent_header; +use super::{parse_http_user_agent_header, RequestInitiator}; use rama_core::error::OpaqueError; use rama_utils::macros::match_ignore_ascii_case_str; use serde::{Deserialize, Deserializer, Serialize}; @@ -14,6 +14,7 @@ pub struct UserAgent { pub(super) http_agent_overwrite: Option, pub(super) tls_agent_overwrite: Option, pub(super) preserve_ua_header: bool, + pub(super) request_initiator: Option, } impl fmt::Display for UserAgent { @@ -76,6 +77,17 @@ impl UserAgent { self.preserve_ua_header } + /// Define the [`RequestInitiator`] hint. 
+ pub fn with_request_initiator(&mut self, req_init: RequestInitiator) -> &mut Self { + self.request_initiator = Some(req_init); + self + } + + /// returns the [`RequestInitiator`] hint if available. + pub fn request_initiator(&self) -> Option { + self.request_initiator + } + /// returns the `User-Agent` (header) value used by the [`UserAgent`]. pub fn header_str(&self) -> &str { &self.header diff --git a/rama-ua/src/ua/mod.rs b/rama-ua/src/ua/mod.rs index d397001c..73201c3a 100644 --- a/rama-ua/src/ua/mod.rs +++ b/rama-ua/src/ua/mod.rs @@ -1,4 +1,7 @@ -use serde::{Deserialize, Serialize}; +use rama_core::error::OpaqueError; +use rama_utils::macros::match_ignore_ascii_case_str; +use serde::{Deserialize, Deserializer, Serialize}; +use std::{fmt, str::FromStr}; mod info; pub use info::{ @@ -27,6 +30,73 @@ pub struct UserAgentOverwrites { pub tls: Option, /// Preserve the original [`UserAgent`] header of the http `Request`. pub preserve_ua: Option, + /// Hint a specific request intiator for UA Emulation. A related + /// or default initiator might be chosen in case the hinted one is not available. + /// + /// In case this hint is not specified it will be gussed for you instead based + /// on the request method and headers. 
+ pub req_init: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RequestInitiator { + Navigate, + Form, + Xhr, + Fetch, +} + +impl RequestInitiator { + pub fn as_str(&self) -> &'static str { + match self { + RequestInitiator::Navigate => "navigate", + RequestInitiator::Form => "form", + RequestInitiator::Xhr => "xhr", + RequestInitiator::Fetch => "fetch", + } + } +} + +impl fmt::Display for RequestInitiator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl Serialize for RequestInitiator { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for RequestInitiator { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + s.parse::() + .map_err(serde::de::Error::custom) + } +} + +impl FromStr for RequestInitiator { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + match_ignore_ascii_case_str! 
{ + match (s) { + "navigate" => Ok(RequestInitiator::Navigate), + "form" => Ok(RequestInitiator::Form), + "xhr" => Ok(RequestInitiator::Xhr), + "fetch" => Ok(RequestInitiator::Fetch), + _ => Err(OpaqueError::from_display(format!("invalid request initiator: {}", s))), + } + } + } } #[cfg(test)] diff --git a/rama-ua/src/ua/parse.rs b/rama-ua/src/ua/parse.rs index 75f2e8de..59193d72 100644 --- a/rama-ua/src/ua/parse.rs +++ b/rama-ua/src/ua/parse.rs @@ -32,6 +32,7 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, } } } @@ -74,6 +75,7 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, }; } else if contains_ignore_ascii_case(ua, "Desktop").is_some() { return UserAgent { @@ -82,6 +84,7 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, }; } else { (None, None, None) @@ -136,6 +139,7 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, }, (None, _, Some(platform)) => UserAgent { header, @@ -143,6 +147,7 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, }, (None, _, None) => UserAgent { header, @@ -150,6 +155,7 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, }, } } diff --git a/tests/integration/examples/example_tests/http_user_agent_classifier.rs 
b/tests/integration/examples/example_tests/http_user_agent_classifier.rs index c870860c..300a47f8 100644 --- a/tests/integration/examples/example_tests/http_user_agent_classifier.rs +++ b/tests/integration/examples/example_tests/http_user_agent_classifier.rs @@ -68,6 +68,7 @@ async fn test_http_user_agent_classifier() { http: Some(HttpAgent::Safari), tls: Some(TlsAgent::Boringssl), preserve_ua: Some(false), + req_init: None, }) .unwrap(), ) From 2b6d3060cd59e867025c544a545f02053dad0949 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 17 Feb 2025 15:09:27 +0100 Subject: [PATCH 12/39] remove highway from dep tree: no longer used in latest version of rama-ua --- Cargo.lock | 7 ------- Cargo.toml | 1 - rama-ua/Cargo.toml | 1 - 3 files changed, 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 179c138d..0a528e1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1212,12 +1212,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "highway" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9040319a6910b901d5d49cbada4a99db52836a1b63228a05f7e2b7f8feef89b1" - [[package]] name = "home" version = "0.5.11" @@ -2521,7 +2515,6 @@ name = "rama-ua" version = "0.2.0-alpha.7" dependencies = [ "bytes", - "highway", "rama-core", "rama-http-types", "rama-net", diff --git a/Cargo.toml b/Cargo.toml index 75df8d7a..550af634 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -150,7 +150,6 @@ want = "0.3" futures-util = "0.3" futures-channel = "0.3" sha2 = "0.10.8" -highway = "1.3.0" jemallocator = { package = "tikv-jemallocator", version = "0.6" } mimalloc = { version = "0.1.39", default-features = false } diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index f3925594..764b73f4 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -19,7 +19,6 @@ tls = ["dep:rama-net", "rama-net/tls"] [dependencies] bytes = { workspace = true } -highway = { workspace = true } rama-core = { version = "0.2.0-alpha.7", path = "../rama-core" 
} rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", optional = true } From 5aa1efff7d6d814dc0f7b629046c7e9c8257eed7 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 17 Feb 2025 22:34:29 +0100 Subject: [PATCH 13/39] fix tls import from rama-net in rama-http-backend --- rama-http-backend/src/client/mod.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/rama-http-backend/src/client/mod.rs b/rama-http-backend/src/client/mod.rs index 80fb1a4c..deeb7067 100644 --- a/rama-http-backend/src/client/mod.rs +++ b/rama-http-backend/src/client/mod.rs @@ -1,6 +1,7 @@ //! Rama HTTP client module, //! which provides the [`HttpClient`] type to serve HTTP requests. +#[cfg(any(feature = "rustls", feature = "boring"))] use std::sync::Arc; use proxy::layer::HttpProxyConnector; @@ -9,17 +10,14 @@ use rama_core::{ Context, Service, }; use rama_http_types::{dep::http_body, Request, Response}; -use rama_net::{ - client::{ConnectorService, EstablishedClientConnection}, - tls::client::ProxyClientConfig, -}; +use rama_net::client::{ConnectorService, EstablishedClientConnection}; use rama_tcp::client::service::TcpConnector; #[cfg(any(feature = "rustls", feature = "boring"))] use rama_tls::std::client::{TlsConnector, TlsConnectorData}; #[cfg(any(feature = "rustls", feature = "boring"))] -use rama_net::tls::client::ClientConfig; +use rama_net::tls::client::{ClientConfig, ProxyClientConfig}; #[cfg(any(feature = "rustls", feature = "boring"))] use rama_core::error::ErrorContext; From 9ba8f4ed75eff3dcfe8fdfac0598700a970ec082 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 17 Feb 2025 22:53:27 +0100 Subject: [PATCH 14/39] fix more warnings and support preserving http and/or tls if instructed --- rama-http/src/layer/ua.rs | 29 ++++++++-------- rama-ua/src/emulate/service.rs | 63 ++++++++++++++++++++++++++-------- rama-ua/src/ua/info.rs | 12 +++---- 3 files 
changed, 70 insertions(+), 34 deletions(-) diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index b1f0893d..5dd8090c 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -123,22 +123,25 @@ where mut ctx: Context, req: Request, ) -> impl Future> + Send + '_ { - let mut user_agent = req - .headers() - .typed_get::() - .map(|ua| UserAgent::new(ua.to_string())); - - if let Some(overwrites) = self + let overvwrites = self .overwrite_header .as_ref() .and_then(|header| req.headers().get(header)) .map(|header| header.as_bytes()) - .and_then(|value| serde_html_form::from_bytes::(value).ok()) - { - if let Some(ua) = overwrites.ua { - user_agent = Some(UserAgent::new(ua)); - } - if let Some(ref mut ua) = user_agent { + .and_then(|value| serde_html_form::from_bytes::(value).ok()); + + let mut user_agent = overvwrites + .as_ref() + .and_then(|o| o.ua.as_ref()) + .map(UserAgent::new) + .or_else(|| { + req.headers() + .typed_get::() + .map(|ua| UserAgent::new(ua.to_string())) + }); + + if let Some(mut ua) = user_agent.take() { + if let Some(overwrites) = overvwrites { if let Some(http_agent) = overwrites.http { ua.with_http_agent(http_agent); } @@ -152,9 +155,7 @@ where ua.with_request_initiator(req_init); } } - } - if let Some(ua) = user_agent.take() { ctx.insert(ua); } diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 65327de5..37ec24b9 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -12,7 +12,7 @@ use rama_http_types::{ }; use rama_utils::macros::match_ignore_ascii_case_str; -use crate::{RequestInitiator, UserAgent, UserAgentProfile}; +use crate::{HttpAgent, RequestInitiator, UserAgent, UserAgentProfile}; use super::UserAgentProvider; @@ -102,24 +102,55 @@ where "user agent profile selected for emulation" ); - emulate_http_settings(&mut ctx, &req, profile); - let _base_http_headers = get_base_http_headers(&ctx, &req, profile); + let preserve_http = matches!( + 
ctx.get::() + .copied() + .or_else(|| ctx.get::().map(|ua| ua.http_agent())), + Some(HttpAgent::Preserve), + ); - // TODO: merge base headers with incoming headers... allowing some to overwrite, others not... - // also allowing anyway to overwrite if something is set... + if preserve_http { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent emulation: skip http settings as http is instructed to be preserved" + ); + } else { + emulate_http_settings(&mut ctx, &req, profile); + let _base_http_headers = get_base_http_headers(&ctx, &req, profile); - #[cfg(feature = "tls")] - { - // this Arc is to be lazilly cloned by a tls connector - // only when a connection is to be made, as to play nicely - // with concepts such as connection pooling - ctx.insert(profile.tls.client_config.clone()); + // TODO: merge base headers with incoming headers... allowing some to overwrite, others not... + // also allowing anyway to overwrite if something is set... 
} - #[allow(clippy::todo)] + #[cfg(feature = "tls")] { - todo!() + use crate::TlsAgent; + + let preserve_tls = matches!( + ctx.get::() + .copied() + .or_else(|| ctx.get::().map(|ua| ua.tls_agent())), + Some(TlsAgent::Preserve), + ); + if preserve_tls { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent emulation: skip tls settings as http is instructed to be preserved" + ); + } else { + // client_config's Arc is to be lazily cloned by a tls connector + // only when a connection is to be made, as to play nicely + // with concepts such as connection pooling + ctx.insert(profile.tls.client_config.clone()); + } } + + // serve emulated http(s) request via inner service + self.inner.serve(ctx, req).await.map_err(Into::into) } } @@ -166,7 +197,11 @@ fn get_base_http_headers<'a, Body, State>( req: &Request, profile: &'a UserAgentProfile, ) -> &'a Http1HeaderMap { - match ctx.get::().and_then(|ua| ua.request_initiator()) { + match ctx + .get::() + .copied() + .or_else(|| ctx.get::().and_then(|ua| ua.request_initiator())) + { Some(req_init) => { tracing::trace!(%req_init, "base http headers defined based on hint from UserAgent (overwrite)"); get_base_http_headers_from_req_init(req_init, profile) diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index 2d169b64..c7d2d493 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -147,8 +147,8 @@ impl UserAgent { /// /// [`UserAgent`]: super::UserAgent pub fn http_agent(&self) -> HttpAgent { - match &self.http_agent_overwrite { - Some(agent) => agent.clone(), + match self.http_agent_overwrite { + Some(agent) => agent, None => match &self.data { UserAgentData::Standard { info, .. 
} => match info.kind { UserAgentKind::Chromium => HttpAgent::Chromium, @@ -166,8 +166,8 @@ impl UserAgent { /// /// [`UserAgent`]: super::UserAgent pub fn tls_agent(&self) -> TlsAgent { - match &self.tls_agent_overwrite { - Some(agent) => agent.clone(), + match self.tls_agent_overwrite { + Some(agent) => agent, None => match &self.data { UserAgentData::Standard { info, .. } => match info.kind { UserAgentKind::Chromium => TlsAgent::Boringssl, @@ -354,7 +354,7 @@ impl fmt::Display for PlatformKind { } /// Http implementation used by the [`UserAgent`] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum HttpAgent { /// Chromium based browsers share the same http implementation Chromium, @@ -422,7 +422,7 @@ impl FromStr for HttpAgent { } /// Tls implementation used by the [`UserAgent`] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TlsAgent { /// Rustls is used as a fallback for all user agents, /// that are not chromium based. 
From b45d6aa8b24ffd71ffba1ca64daa3d27607bad9d Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Tue, 18 Feb 2025 15:06:41 +0100 Subject: [PATCH 15/39] support random UA selection and cleaner select fallbacks in general --- examples/http_user_agent_classifier.rs | 4 +- rama-http/src/layer/ua.rs | 44 +++++++-- rama-ua/src/emulate/layer.rs | 26 ++++- rama-ua/src/emulate/mod.rs | 2 +- rama-ua/src/emulate/provider.rs | 23 ++++- rama-ua/src/emulate/service.rs | 31 ++++-- rama-ua/src/profile/db.rs | 80 +++++++++------ rama-ua/src/ua/info.rs | 76 ++++++++------ rama-ua/src/ua/parse.rs | 132 +++++++++++++------------ rama-ua/src/ua/parse_tests.rs | 130 ++++++++++++------------ 10 files changed, 339 insertions(+), 209 deletions(-) diff --git a/examples/http_user_agent_classifier.rs b/examples/http_user_agent_classifier.rs index 032d2bd1..6ff74870 100644 --- a/examples/http_user_agent_classifier.rs +++ b/examples/http_user_agent_classifier.rs @@ -47,8 +47,8 @@ async fn handle(ctx: Context<()>, _req: Request) -> Result "kind": ua.info().map(|info| info.kind.to_string()), "version": ua.info().and_then(|info| info.version), "platform": ua.platform().map(|p| p.to_string()), - "http_agent": ua.http_agent().to_string(), - "tls_agent": ua.tls_agent().to_string(), + "http_agent": ua.http_agent().as_ref().map(ToString::to_string), + "tls_agent": ua.tls_agent().as_ref().map(ToString::to_string), })) .into_response()) } diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index 5dd8090c..a336c87c 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -123,16 +123,16 @@ where mut ctx: Context, req: Request, ) -> impl Future> + Send + '_ { - let overvwrites = self + let overwrites = self .overwrite_header .as_ref() .and_then(|header| req.headers().get(header)) .map(|header| header.as_bytes()) .and_then(|value| serde_html_form::from_bytes::(value).ok()); - let mut user_agent = overvwrites + let mut user_agent = overwrites .as_ref() - .and_then(|o| 
o.ua.as_ref()) + .and_then(|o| o.ua.as_deref()) .map(UserAgent::new) .or_else(|| { req.headers() @@ -141,7 +141,7 @@ where }); if let Some(mut ua) = user_agent.take() { - if let Some(overwrites) = overvwrites { + if let Some(overwrites) = overwrites { if let Some(http_agent) = overwrites.http { ua.with_http_agent(http_agent); } @@ -241,6 +241,34 @@ mod tests { .unwrap(); } + #[tokio::test] + async fn test_user_agent_classifier_layer_ua_iphone_app() { + const UA: &str = "iPhone App/1.0"; + + async fn handle(ctx: Context, _req: Request) -> Result { + let ua: &UserAgent = ctx.get().unwrap(); + + assert_eq!(ua.header_str(), UA); + assert!(ua.info().is_none()); + assert_eq!(ua.platform(), Some(PlatformKind::IOS)); + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); + assert!(!ua.preserve_ua_header()); + assert!(ua.request_initiator().is_none()); + + Ok(StatusCode::OK.into_response()) + } + + let service = UserAgentClassifierLayer::new().layer(service_fn(handle)); + + let _ = service + .get("http://www.example.com") + .typed_header(headers::UserAgent::from_static(UA)) + .send(Context::default()) + .await + .unwrap(); + } + #[tokio::test] async fn test_user_agent_classifier_layer_ua_chrome() { const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67"; @@ -311,9 +339,9 @@ mod tests { assert_eq!(ua.header_str(), UA); assert!(ua.info().is_none()); - assert!(ua.platform().is_none()); - assert_eq!(ua.http_agent(), HttpAgent::Safari); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + assert_eq!(ua.platform(), Some(PlatformKind::IOS)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); assert!(ua.preserve_ua_header()); assert_eq!(ua.request_initiator(), Some(RequestInitiator::Xhr)); @@ -330,7 +358,7 @@ mod tests { "x-proxy-ua", serde_html_form::to_string(&UserAgentOverwrites { ua: Some(UA.to_owned()), - 
http: Some(HttpAgent::Safari), + http: Some(HttpAgent::Firefox), tls: Some(TlsAgent::Boringssl), preserve_ua: Some(true), req_init: Some(RequestInitiator::Xhr), diff --git a/rama-ua/src/emulate/layer.rs b/rama-ua/src/emulate/layer.rs index 35f89f3b..14858c87 100644 --- a/rama-ua/src/emulate/layer.rs +++ b/rama-ua/src/emulate/layer.rs @@ -2,9 +2,12 @@ use std::fmt; use rama_core::Layer; +use super::UserAgentSelectFallback; + pub struct UserAgentEmulateLayer

{ provider: P, optional: bool, + select_fallback: Option, } impl fmt::Debug for UserAgentEmulateLayer

{ @@ -12,6 +15,7 @@ impl fmt::Debug for UserAgentEmulateLayer

{ f.debug_struct("UserAgentEmulateLayer") .field("provider", &self.provider) .field("optional", &self.optional) + .field("select_fallback", &self.select_fallback) .finish() } } @@ -21,6 +25,7 @@ impl Clone for UserAgentEmulateLayer

{ Self { provider: self.provider.clone(), optional: self.optional, + select_fallback: self.select_fallback, } } } @@ -30,6 +35,7 @@ impl

UserAgentEmulateLayer

{ Self { provider, optional: false, + select_fallback: None, } } @@ -46,12 +52,30 @@ impl

UserAgentEmulateLayer

{ self.optional = optional; self } + + /// Choose what to do in case no profile could be selected + /// using the regular pre-conditions as specified by the provider. + pub fn select_fallback(mut self, fb: UserAgentSelectFallback) -> Self { + self.select_fallback = Some(fb); + self + } + + /// See [`Self::select_fallback`]. + pub fn set_select_fallback(&mut self, fb: UserAgentSelectFallback) -> &mut Self { + self.select_fallback = Some(fb); + self + } } impl Layer for UserAgentEmulateLayer

{ type Service = super::UserAgentEmulateService; fn layer(&self, inner: S) -> Self::Service { - super::UserAgentEmulateService::new(inner, self.provider.clone()).optional(self.optional) + let mut svc = super::UserAgentEmulateService::new(inner, self.provider.clone()) + .optional(self.optional); + if let Some(fb) = self.select_fallback { + svc.set_select_fallback(fb); + } + svc } } diff --git a/rama-ua/src/emulate/mod.rs b/rama-ua/src/emulate/mod.rs index 1ffc74c0..bb78ac71 100644 --- a/rama-ua/src/emulate/mod.rs +++ b/rama-ua/src/emulate/mod.rs @@ -1,5 +1,5 @@ mod provider; -pub use provider::UserAgentProvider; +pub use provider::{UserAgentProvider, UserAgentSelectFallback}; mod layer; pub use layer::UserAgentEmulateLayer; diff --git a/rama-ua/src/emulate/provider.rs b/rama-ua/src/emulate/provider.rs index 91edb42d..5ce04c7c 100644 --- a/rama-ua/src/emulate/provider.rs +++ b/rama-ua/src/emulate/provider.rs @@ -4,6 +4,23 @@ use rama_core::Context; use crate::{UserAgentDatabase, UserAgentProfile}; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +/// Fallback strategy that can be injected into the context +/// to customise what a provider can be requested to do +/// in case the preconditions for UA selection were not fulfilled. +/// +/// It is advised to only fall back for pre-conditions and not +/// post-selection failure as the latter would be rather confusing. +/// +/// For example if you request a Chromium profile you do not expect a Firefox one. +/// However if you do not give any filters it is fair to assume a random profile is desired, +/// given those all satisfy the absence of filters. 
+pub enum UserAgentSelectFallback { + #[default] + Abort, + Random, +} + pub trait UserAgentProvider: Send + Sync + 'static { fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile>; } @@ -25,7 +42,11 @@ impl UserAgentProvider for UserAgentProfile { impl UserAgentProvider for UserAgentDatabase { #[inline] fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { - ctx.get().and_then(|agent| self.get(agent)) + match (ctx.get(), ctx.get()) { + (Some(agent), _) => self.get(agent), + (None, Some(UserAgentSelectFallback::Random)) => self.rnd(), + (None, None | Some(UserAgentSelectFallback::Abort)) => None, + } } } diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 37ec24b9..3b7148aa 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -14,12 +14,13 @@ use rama_utils::macros::match_ignore_ascii_case_str; use crate::{HttpAgent, RequestInitiator, UserAgent, UserAgentProfile}; -use super::UserAgentProvider; +use super::{UserAgentProvider, UserAgentSelectFallback}; pub struct UserAgentEmulateService { inner: S, provider: P, optional: bool, + select_fallback: Option, } impl fmt::Debug for UserAgentEmulateService { @@ -28,6 +29,7 @@ impl fmt::Debug for UserAgentEmulateService .field("inner", &self.inner) .field("provider", &self.provider) .field("optional", &self.optional) + .field("select_fallback", &self.select_fallback) .finish() } } @@ -38,6 +40,7 @@ impl Clone for UserAgentEmulateService { inner: self.inner.clone(), provider: self.provider.clone(), optional: self.optional, + select_fallback: self.select_fallback, } } } @@ -48,6 +51,7 @@ impl UserAgentEmulateService { inner, provider, optional: false, + select_fallback: None, } } @@ -64,6 +68,19 @@ impl UserAgentEmulateService { self.optional = optional; self } + + /// Choose what to do in case no profile could be selected + /// using the regular pre-conditions as specified by the provider. 
+ pub fn select_fallback(mut self, fb: UserAgentSelectFallback) -> Self { + self.select_fallback = Some(fb); + self + } + + /// See [`Self::select_fallback`]. + pub fn set_select_fallback(&mut self, fb: UserAgentSelectFallback) -> &mut Self { + self.select_fallback = Some(fb); + self + } } impl Service> for UserAgentEmulateService @@ -81,6 +98,10 @@ where mut ctx: Context, req: Request, ) -> Result { + if let Some(fallback) = self.select_fallback { + ctx.insert(fallback); + } + let profile = match self.provider.select_user_agent_profile(&ctx) { Some(profile) => profile, None => { @@ -103,9 +124,7 @@ where ); let preserve_http = matches!( - ctx.get::() - .copied() - .or_else(|| ctx.get::().map(|ua| ua.http_agent())), + ctx.get::().and_then(|ua| ua.http_agent()), Some(HttpAgent::Preserve), ); @@ -129,9 +148,7 @@ where use crate::TlsAgent; let preserve_tls = matches!( - ctx.get::() - .copied() - .or_else(|| ctx.get::().map(|ua| ua.tls_agent())), + ctx.get::().and_then(|ua| ua.tls_agent()), Some(TlsAgent::Preserve), ); if preserve_tls { diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index b354116c..835882e2 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -9,6 +9,7 @@ pub struct UserAgentDatabase { map_ua_string: HashMap, + map_ua_kind: HashMap>, map_platform: HashMap<(UserAgentKind, PlatformKind), Vec>, map_device: HashMap<(UserAgentKind, DeviceKind), Vec>, } @@ -20,6 +21,11 @@ impl UserAgentDatabase { self.map_ua_string.insert(ua_header.to_owned(), index); } + self.map_ua_kind + .entry(profile.ua_kind) + .or_default() + .push(index); + if let Some(platform) = profile.platform { self.map_platform .entry((profile.ua_kind, platform)) @@ -34,6 +40,14 @@ impl UserAgentDatabase { self.profiles.push(profile); } + pub fn rnd(&self) -> Option<&UserAgentProfile> { + let ua_kind = self.market_rnd_ua_kind(); + self.map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + 
pub fn get(&self, ua: &UserAgent) -> Option<&UserAgentProfile> { if let Some(profile) = self .map_ua_string @@ -52,42 +66,34 @@ impl UserAgentDatabase { .and_then(|v| v.choose(&mut rand::rng())) .and_then(|idx| self.profiles.get(*idx)), // UA + Device match (e.g. firefox desktop) - None => { - let device = ua.device(); - self.map_device + None => match ua.device() { + Some(device) => self + .map_device .get(&(ua_kind, device)) .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)) - } + .and_then(|idx| self.profiles.get(*idx)), + None => self + .map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)), + }, }, // Market-share Kind + Device match (e.g. chrome desktop) None => { - let device = ua.device(); - - // market share from - // https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) - let r = rand::random_range(0..=100); - let ua_kind = if r < 3 - && self - .map_device - .contains_key(&(UserAgentKind::Firefox, device)) - { - UserAgentKind::Firefox - } else if r < 18 - && self + let ua_kind = self.market_rnd_ua_kind(); + match ua.device() { + Some(device) => self .map_device - .contains_key(&(UserAgentKind::Safari, device)) - { - UserAgentKind::Safari - } else { - // ~79% x-x - UserAgentKind::Chromium - }; - - self.map_device - .get(&(ua_kind, device)) - .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)) + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)), + None => self + .map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)), + } } } } @@ -96,6 +102,20 @@ impl UserAgentDatabase { pub fn iter(&self) -> impl Iterator { self.profiles.iter() } + + /// market share from + /// https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + fn market_rnd_ua_kind(&self) -> UserAgentKind { + 
let r = rand::random_range(0..=100); + if r < 3 && self.map_ua_kind.contains_key(&UserAgentKind::Firefox) { + UserAgentKind::Firefox + } else if r < 18 && self.map_ua_kind.contains_key(&UserAgentKind::Safari) { + UserAgentKind::Safari + } else { + // ~79% x-x + UserAgentKind::Chromium + } + } } impl FromIterator for UserAgentDatabase { diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index c7d2d493..1295a9f5 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -2,14 +2,14 @@ use super::{parse_http_user_agent_header, RequestInitiator}; use rama_core::error::OpaqueError; use rama_utils::macros::match_ignore_ascii_case_str; use serde::{Deserialize, Deserializer, Serialize}; -use std::{convert::Infallible, fmt, str::FromStr}; +use std::{convert::Infallible, fmt, str::FromStr, sync::Arc}; /// User Agent (UA) information. /// /// See [the module level documentation](crate) for more information. #[derive(Debug, Clone)] pub struct UserAgent { - pub(super) header: String, + pub(super) header: Arc, pub(super) data: UserAgentData, pub(super) http_agent_overwrite: Option, pub(super) tls_agent_overwrite: Option, @@ -28,13 +28,28 @@ impl fmt::Display for UserAgent { pub(super) enum UserAgentData { Standard { info: UserAgentInfo, - platform: Option, + platform_like: Option, }, Platform(PlatformKind), Device(DeviceKind), Unknown, } +#[derive(Debug, Clone)] +pub(super) enum PlatformLike { + Platform(PlatformKind), + Device(DeviceKind), +} + +impl PlatformLike { + pub(super) fn device(&self) -> DeviceKind { + match self { + PlatformLike::Platform(platform_kind) => platform_kind.device(), + PlatformLike::Device(device_kind) => *device_kind, + } + } +} + /// Information about the [`UserAgent`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct UserAgentInfo { @@ -46,7 +61,7 @@ pub struct UserAgentInfo { impl UserAgent { /// Create a new [`UserAgent`] from a `User-Agent` (header) value. 
- pub fn new(header: impl Into) -> Self { + pub fn new(header: impl Into>) -> Self { parse_http_user_agent_header(header.into()) } @@ -94,14 +109,14 @@ impl UserAgent { } /// returns the device kind of the [`UserAgent`]. - pub fn device(&self) -> DeviceKind { + pub fn device(&self) -> Option { match &self.data { - UserAgentData::Standard { platform, .. } => { - platform.map(|p| p.device()).unwrap_or(DeviceKind::Desktop) - } - UserAgentData::Platform(platform) => platform.device(), - UserAgentData::Device(kind) => *kind, - UserAgentData::Unknown => DeviceKind::Desktop, + UserAgentData::Standard { + ref platform_like, .. + } => platform_like.as_ref().map(|p| p.device()), + UserAgentData::Platform(platform) => Some(platform.device()), + UserAgentData::Device(kind) => Some(*kind), + UserAgentData::Unknown => None, } } @@ -137,7 +152,10 @@ impl UserAgent { /// This is the platform the [`UserAgent`] is running on. pub fn platform(&self) -> Option { match &self.data { - UserAgentData::Standard { platform, .. } => *platform, + UserAgentData::Standard { platform_like, .. } => match platform_like { + Some(PlatformLike::Platform(platform)) => Some(*platform), + None | Some(PlatformLike::Device(_)) => None, + }, UserAgentData::Platform(platform) => Some(*platform), _ => None, } @@ -146,17 +164,17 @@ impl UserAgent { /// returns the [`HttpAgent`] used by the [`UserAgent`]. /// /// [`UserAgent`]: super::UserAgent - pub fn http_agent(&self) -> HttpAgent { + pub fn http_agent(&self) -> Option { match self.http_agent_overwrite { - Some(agent) => agent, + Some(agent) => Some(agent), None => match &self.data { - UserAgentData::Standard { info, .. } => match info.kind { + UserAgentData::Standard { info, .. 
} => Some(match info.kind { UserAgentKind::Chromium => HttpAgent::Chromium, UserAgentKind::Firefox => HttpAgent::Firefox, UserAgentKind::Safari => HttpAgent::Safari, - }, - UserAgentData::Device(_) | UserAgentData::Platform(_) | UserAgentData::Unknown => { - HttpAgent::Chromium + }), + UserAgentData::Platform(_) | UserAgentData::Device(_) | UserAgentData::Unknown => { + None } }, } @@ -165,17 +183,17 @@ impl UserAgent { /// returns the [`TlsAgent`] used by the [`UserAgent`]. /// /// [`UserAgent`]: super::UserAgent - pub fn tls_agent(&self) -> TlsAgent { + pub fn tls_agent(&self) -> Option { match self.tls_agent_overwrite { - Some(agent) => agent, + Some(agent) => Some(agent), None => match &self.data { - UserAgentData::Standard { info, .. } => match info.kind { + UserAgentData::Standard { info, .. } => Some(match info.kind { UserAgentKind::Chromium => TlsAgent::Boringssl, UserAgentKind::Firefox => TlsAgent::Nss, UserAgentKind::Safari => TlsAgent::Rustls, - }, + }), UserAgentData::Device(_) | UserAgentData::Platform(_) | UserAgentData::Unknown => { - TlsAgent::Rustls + None } }, } @@ -507,9 +525,9 @@ mod tests { }) ); assert_eq!(ua.platform(), Some(PlatformKind::MacOS)); - assert_eq!(ua.device(), DeviceKind::Desktop); - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); } #[test] @@ -524,9 +542,9 @@ mod tests { }) ); assert_eq!(ua.platform(), Some(PlatformKind::MacOS)); - assert_eq!(ua.device(), DeviceKind::Desktop); - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); } #[test] diff --git a/rama-ua/src/ua/parse.rs 
b/rama-ua/src/ua/parse.rs index 59193d72..1770e64d 100644 --- a/rama-ua/src/ua/parse.rs +++ b/rama-ua/src/ua/parse.rs @@ -1,7 +1,9 @@ #![allow(dead_code)] +use std::sync::Arc; + use super::{ - info::{UserAgentData, UserAgentInfo}, + info::{PlatformLike, UserAgentData, UserAgentInfo}, DeviceKind, PlatformKind, UserAgent, UserAgentKind, }; @@ -20,8 +22,9 @@ const MAX_UA_LENGTH: usize = 512; /// - complete: we do not care about all the possible user agents out there, only the popular ones. /// /// That said. Do open a ticket if you find bugs or think something is missing. -pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { - let ua = header.as_str(); +pub(crate) fn parse_http_user_agent_header(header: impl Into>) -> UserAgent { + let header = header.into(); + let ua = header.as_ref(); let ua = if ua.len() > MAX_UA_LENGTH { match ua.get(..MAX_UA_LENGTH) { Some(s) => s, @@ -40,108 +43,99 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { ua }; - let (kind, kind_version, maybe_platform) = if let Some(loc) = - contains_ignore_ascii_case(ua, "Firefox") - { - let kind = UserAgentKind::Firefox; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); - (Some(kind), kind_version, None) - } else if let Some(loc) = contains_ignore_ascii_case(ua, "Chrom") { - let kind = UserAgentKind::Chromium; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); - (Some(kind), kind_version, None) - } else if contains_ignore_ascii_case(ua, "Safari").is_some() { - if let Some(firefox_loc) = contains_ignore_ascii_case(ua, "FxiOS") { + let (kind, kind_version, maybe_platform) = + if let Some(loc) = contains_ignore_ascii_case(ua, "Firefox") { let kind = UserAgentKind::Firefox; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[firefox_loc..]); - (Some(kind), kind_version, Some(PlatformKind::IOS)) - } else if let Some(chrome_loc) = contains_ignore_ascii_case(ua, "CriOS") { - let kind = 
UserAgentKind::Chromium; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[chrome_loc..]); - (Some(kind), kind_version, Some(PlatformKind::IOS)) - } else if let Some(chromium_loc) = contains_any_ignore_ascii_case(ua, &["Opera"]) { + let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); + (Some(kind), kind_version, None) + } else if let Some(loc) = contains_ignore_ascii_case(ua, "Chrom") { let kind = UserAgentKind::Chromium; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[chromium_loc..]); + let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); (Some(kind), kind_version, None) + } else if contains_ignore_ascii_case(ua, "Safari").is_some() { + if let Some(firefox_loc) = contains_ignore_ascii_case(ua, "FxiOS") { + let kind = UserAgentKind::Firefox; + let kind_version = parse_ua_version_firefox_and_chromium(&ua[firefox_loc..]); + (Some(kind), kind_version, Some(PlatformKind::IOS)) + } else if let Some(chrome_loc) = contains_ignore_ascii_case(ua, "CriOS") { + let kind = UserAgentKind::Chromium; + let kind_version = parse_ua_version_firefox_and_chromium(&ua[chrome_loc..]); + (Some(kind), kind_version, Some(PlatformKind::IOS)) + } else if let Some(chromium_loc) = contains_any_ignore_ascii_case(ua, &["Opera"]) { + let kind = UserAgentKind::Chromium; + let kind_version = parse_ua_version_firefox_and_chromium(&ua[chromium_loc..]); + (Some(kind), kind_version, None) + } else { + let kind = UserAgentKind::Safari; + let kind_version = parse_ua_version_safari(ua); + (Some(kind), kind_version, None) + } } else { - let kind = UserAgentKind::Safari; - let kind_version = parse_ua_version_safari(ua); - (Some(kind), kind_version, None) - } - } else if contains_any_ignore_ascii_case(ua, &["Mobile", "Phone", "Tablet", "Zune"]).is_some() { - return UserAgent { - header, - data: UserAgentData::Device(DeviceKind::Mobile), - http_agent_overwrite: None, - tls_agent_overwrite: None, - preserve_ua_header: false, - request_initiator: 
None, - }; - } else if contains_ignore_ascii_case(ua, "Desktop").is_some() { - return UserAgent { - header, - data: UserAgentData::Device(DeviceKind::Desktop), - http_agent_overwrite: None, - tls_agent_overwrite: None, - preserve_ua_header: false, - request_initiator: None, + (None, None, None) }; - } else { - (None, None, None) - }; - let maybe_platform = match maybe_platform { - Some(platform) => Some(platform), + let (maybe_platform, maybe_device) = match maybe_platform { + Some(platform) => (Some(platform), None), None => { if contains_ignore_ascii_case(ua, "Windows").is_some() { if contains_ignore_ascii_case(ua, "X11").is_some() { - None + (None, Some(DeviceKind::Mobile)) } else { - Some(PlatformKind::Windows) + (Some(PlatformKind::Windows), None) } } else if contains_ignore_ascii_case(ua, "Android").is_some() { if contains_ignore_ascii_case(ua, "iOS").is_some() { - Some(PlatformKind::IOS) + (Some(PlatformKind::IOS), None) } else { - Some(PlatformKind::Android) + (Some(PlatformKind::Android), None) } } else if contains_ignore_ascii_case(ua, "Linux").is_some() { if contains_any_ignore_ascii_case(ua, &["Mobile", "UCW"]).is_some() { - Some(PlatformKind::Android) + (Some(PlatformKind::Android), None) } else { - Some(PlatformKind::Linux) + (Some(PlatformKind::Linux), None) } } else if contains_any_ignore_ascii_case(ua, &["iOS", "iPad", "iPod", "iPhone"]) .is_some() { - Some(PlatformKind::IOS) + (Some(PlatformKind::IOS), None) } else if contains_ignore_ascii_case(ua, "Mac").is_some() { - Some(PlatformKind::MacOS) + (Some(PlatformKind::MacOS), None) } else if contains_ignore_ascii_case(ua, "Darwin").is_some() { if contains_ignore_ascii_case(ua, "86").is_some() { - Some(PlatformKind::MacOS) + (Some(PlatformKind::MacOS), None) } else { - Some(PlatformKind::IOS) + (Some(PlatformKind::IOS), None) } + } else if contains_any_ignore_ascii_case(ua, &["Mobile", "Phone", "Tablet", "Zune"]) + .is_some() + { + (None, Some(DeviceKind::Mobile)) + } else if 
contains_ignore_ascii_case(ua, "Desktop").is_some() { + (None, Some(DeviceKind::Desktop)) } else { - None + (None, None) } } }; - match (kind, kind_version, maybe_platform) { - (Some(kind), version, platform) => UserAgent { + match (kind, kind_version, maybe_platform, maybe_device) { + (Some(kind), version, platform, device) => UserAgent { header, data: UserAgentData::Standard { info: UserAgentInfo { kind, version }, - platform, + platform_like: match (platform, device) { + (Some(platform), _) => Some(PlatformLike::Platform(platform)), + (None, Some(device)) => Some(PlatformLike::Device(device)), + (None, None) => None, + }, }, http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, request_initiator: None, }, - (None, _, Some(platform)) => UserAgent { + (None, _, Some(platform), _) => UserAgent { header, data: UserAgentData::Platform(platform), http_agent_overwrite: None, @@ -149,7 +143,15 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { preserve_ua_header: false, request_initiator: None, }, - (None, _, None) => UserAgent { + (None, _, None, Some(device)) => UserAgent { + header, + data: UserAgentData::Device(device), + http_agent_overwrite: None, + tls_agent_overwrite: None, + preserve_ua_header: false, + request_initiator: None, + }, + (None, _, None, None) => UserAgent { header, data: UserAgentData::Unknown, http_agent_overwrite: None, diff --git a/rama-ua/src/ua/parse_tests.rs b/rama-ua/src/ua/parse_tests.rs index 944634e3..79ac8e85 100644 --- a/rama-ua/src/ua/parse_tests.rs +++ b/rama-ua/src/ua/parse_tests.rs @@ -8,22 +8,22 @@ fn test_parse_desktop_ua() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert!(ua.info().is_none()); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), 
TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -32,22 +32,22 @@ fn test_parse_too_long_ua() { let mut ua = UserAgent::new(ua_str.clone()); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), None); assert_eq!(ua.info(), None); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -56,22 +56,22 @@ fn test_parse_windows() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!(ua.info(), None); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // 
Http/Tls agents have defaults for platforms + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -80,7 +80,7 @@ fn test_parse_chrome() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), None); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -90,18 +90,18 @@ fn test_parse_chrome() { ); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -110,7 +110,7 @@ fn test_parse_windows_chrome() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -120,18 +120,18 @@ fn test_parse_windows_chrome() { ); assert_eq!(ua.platform(), 
Some(PlatformKind::Windows)); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -140,7 +140,7 @@ fn test_parse_desktop_chrome() { let ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -157,7 +157,7 @@ fn test_parse_desktop_chrome_with_version() { let ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -174,7 +174,7 @@ fn test_parse_windows_chrome_with_version() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -184,18 +184,18 @@ fn test_parse_windows_chrome_with_version() { ); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + 
assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -204,22 +204,22 @@ fn test_parse_mobile_ua() { let mut ua = UserAgent::new(*ua_str); assert_eq!(ua.header_str(), *ua_str); - assert_eq!(ua.device(), DeviceKind::Mobile); + assert_eq!(ua.device(), Some(DeviceKind::Mobile)); assert_eq!(ua.info(), None); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } } @@ -230,24 +230,24 @@ fn test_parse_happy_path_unknown_ua() { // UA Is always stored as is. assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), None); // No information should be known about the UA. 
assert!(ua.info().is_none()); assert!(ua.platform().is_none()); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -256,7 +256,7 @@ fn test_parse_happy_path_ua_macos_chrome() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -266,18 +266,18 @@ fn test_parse_happy_path_ua_macos_chrome() { ); assert_eq!(ua.platform(), Some(PlatformKind::MacOS)); - // Http/Tls - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] From 6c894f62f4e71d3431d6cd1148690d9664c714d8 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker 
Date: Tue, 18 Feb 2025 15:37:05 +0100 Subject: [PATCH 16/39] fix tests and improve UA (db) profile selection randomness --- Cargo.lock | 1 + rama-ua/Cargo.toml | 1 + rama-ua/src/profile/db.rs | 118 ++++++++++++------ .../http_user_agent_classifier.rs | 6 +- 4 files changed, 86 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a528e1e..ce9fa535 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2515,6 +2515,7 @@ name = "rama-ua" version = "0.2.0-alpha.7" dependencies = [ "bytes", + "itertools 0.14.0", "rama-core", "rama-http-types", "rama-net", diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index 764b73f4..8dc15686 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -19,6 +19,7 @@ tls = ["dep:rama-net", "rama-net/tls"] [dependencies] bytes = { workspace = true } +itertools = { workspace = true } rama-core = { version = "0.2.0-alpha.7", path = "../rama-core" } rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", optional = true } diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 835882e2..7fc03e23 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -1,3 +1,4 @@ +use itertools::Itertools as _; use rand::seq::IndexedRandom as _; use std::collections::HashMap; @@ -15,6 +16,25 @@ pub struct UserAgentDatabase { } impl UserAgentDatabase { + pub fn iter_ua_str(&self) -> impl Iterator { + self.map_ua_string.keys().map(|s| s.as_str()) + } + + pub fn iter_ua_kind(&self) -> impl Iterator { + self.map_ua_kind.keys() + } + + pub fn iter_platform(&self) -> impl Iterator { + self.map_platform + .keys() + .map(|(_, platform)| platform) + .dedup() + } + + pub fn iter_device(&self) -> impl Iterator { + self.map_device.keys().map(|(_, device)| device).dedup() + } + pub fn insert(&mut self, profile: UserAgentProfile) { let index = self.profiles.len(); if let Some(ua_header) = profile.ua_str() { @@ -57,42 +77,64 @@ impl 
UserAgentDatabase { return Some(profile); } - match ua.ua_kind() { - Some(ua_kind) => match ua.platform() { + match (ua.ua_kind(), ua.platform(), ua.device()) { + (Some(ua_kind), Some(platform), _) => { // UA + Platform Match (e.g. chrome windows) - Some(platform) => self - .map_platform + self.map_platform .get(&(ua_kind, platform)) .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)), + .and_then(|idx| self.profiles.get(*idx)) + } + (Some(ua_kind), None, Some(device)) => { // UA + Device match (e.g. firefox desktop) - None => match ua.device() { - Some(device) => self - .map_device - .get(&(ua_kind, device)) - .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)), - None => self - .map_ua_kind - .get(&ua_kind) - .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)), - }, - }, - // Market-share Kind + Device match (e.g. chrome desktop) - None => { - let ua_kind = self.market_rnd_ua_kind(); - match ua.device() { - Some(device) => self - .map_device - .get(&(ua_kind, device)) - .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)), - None => self - .map_ua_kind - .get(&ua_kind) - .and_then(|v| v.choose(&mut rand::rng())) - .and_then(|idx| self.profiles.get(*idx)), + self.map_device + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (Some(ua_kind), None, None) => { + // random profile for this UA + self.map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (None, Some(platform), _) => { + // NOTE: I guestimated these numbers... 
Feel free to help improve these + let ua_kind = match platform { + PlatformKind::Windows => self.market_rnd_ua_kind_with_shares(7, 0), + PlatformKind::MacOS => self.market_rnd_ua_kind_with_shares(9, 35), + PlatformKind::Linux => self.market_rnd_ua_kind_with_shares(22, 0), + PlatformKind::Android => self.market_rnd_ua_kind_with_shares(3, 0), + PlatformKind::IOS => self.market_rnd_ua_kind_with_shares(5, 42), + }; + self.map_platform + .get(&(ua_kind, platform)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (None, None, device) => { + // random ua kind matching with device or not + match device { + Some(device) => { + let ua_kind = match device { + // https://gs.statcounter.com/browser-market-share/desktop/worldwide (feb 2025) + DeviceKind::Desktop => self.market_rnd_ua_kind_with_shares(7, 9), + // https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + DeviceKind::Mobile => self.market_rnd_ua_kind_with_shares(1, 23), + }; + self.map_device + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + None => { + let ua_kind = self.market_rnd_ua_kind(); + self.map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } } } } @@ -103,16 +145,18 @@ impl UserAgentDatabase { self.profiles.iter() } - /// market share from - /// https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) fn market_rnd_ua_kind(&self) -> UserAgentKind { + // https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + self.market_rnd_ua_kind_with_shares(3, 18) + } + + fn market_rnd_ua_kind_with_shares(&self, firefox: i32, safari: i32) -> UserAgentKind { let r = rand::random_range(0..=100); - if r < 3 && self.map_ua_kind.contains_key(&UserAgentKind::Firefox) { + if r < firefox && self.map_ua_kind.contains_key(&UserAgentKind::Firefox) { UserAgentKind::Firefox - } else if r < 18 
&& self.map_ua_kind.contains_key(&UserAgentKind::Safari) { + } else if r < safari + firefox && self.map_ua_kind.contains_key(&UserAgentKind::Safari) { UserAgentKind::Safari } else { - // ~79% x-x UserAgentKind::Chromium } } diff --git a/tests/integration/examples/example_tests/http_user_agent_classifier.rs b/tests/integration/examples/example_tests/http_user_agent_classifier.rs index 300a47f8..a1570d1f 100644 --- a/tests/integration/examples/example_tests/http_user_agent_classifier.rs +++ b/tests/integration/examples/example_tests/http_user_agent_classifier.rs @@ -35,8 +35,8 @@ async fn test_http_user_agent_classifier() { assert_eq!(ua_rama.kind, None); assert_eq!(ua_rama.version, None); assert_eq!(ua_rama.platform, None); - assert_eq!(ua_rama.http_agent, Some("Chromium".to_owned())); - assert_eq!(ua_rama.tls_agent, Some("Rustls".to_owned())); + assert_eq!(ua_rama.http_agent, None); + assert_eq!(ua_rama.tls_agent, None); const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67"; @@ -81,7 +81,7 @@ async fn test_http_user_agent_classifier() { assert_eq!(ua_app.ua, UA_APP); assert!(ua_app.kind.is_none()); assert!(ua_app.version.is_none()); - assert!(ua_app.platform.is_none()); + assert_eq!(ua_app.platform, Some("iOS".to_owned())); assert_eq!(ua_app.http_agent, Some("Safari".to_owned())); assert_eq!(ua_app.tls_agent, Some("Boringssl".to_owned())); } From a5f0c661f81df0f2ad5cf469136f558aa5ae767a Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Tue, 18 Feb 2025 18:15:34 +0100 Subject: [PATCH 17/39] support client config based on context in tls connectors --- rama-tls/src/boring/client/connector.rs | 55 ++++++++++++++++++++++--- rama-tls/src/rustls/client/connector.rs | 50 ++++++++++++++++++++-- 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/rama-tls/src/boring/client/connector.rs b/rama-tls/src/boring/client/connector.rs index eb66c427..42647528 100644 
--- a/rama-tls/src/boring/client/connector.rs +++ b/rama-tls/src/boring/client/connector.rs @@ -2,15 +2,16 @@ use super::TlsConnectorData; use crate::types::TlsTunnel; use pin_project_lite::pin_project; use private::{ConnectorKindAuto, ConnectorKindSecure, ConnectorKindTunnel}; -use rama_core::error::{BoxError, ErrorExt, OpaqueError}; +use rama_core::error::{BoxError, ErrorContext, ErrorExt, OpaqueError}; use rama_core::{Context, Layer, Service}; use rama_net::address::Host; use rama_net::client::{ConnectorService, EstablishedClientConnection}; use rama_net::stream::Stream; -use rama_net::tls::client::NegotiatedTlsParameters; +use rama_net::tls::client::{ClientConfig, NegotiatedTlsParameters}; use rama_net::tls::ApplicationProtocol; use rama_net::transport::TryRefIntoTransportContext; use std::fmt; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_boring::SslStream; @@ -266,7 +267,21 @@ where let host = transport_ctx.authority.host().clone(); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into boring connector data")?, + ), + None => None, + }, + }; + let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?; tracing::trace!( @@ -324,7 +339,21 @@ where let host = transport_ctx.authority.host().clone(); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into boring connector data")?, + ), + None => None, + }, + }; + let (conn, negotiated_params) = 
self.handshake(connector_data, host, conn).await?; ctx.insert(negotiated_params); @@ -380,7 +409,21 @@ where } }; - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into boring connector data")?, + ), + None => None, + }, + }; + let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?; ctx.insert(negotiated_params); @@ -409,7 +452,7 @@ impl TlsConnector { let connector_data = connector_data.as_ref().or(self.connector_data.as_ref()); let client_config_data = match connector_data { Some(connector_data) => connector_data.try_to_build_config()?, - None => TlsConnectorData::new_http_auto()?.try_to_build_config()?, + None => TlsConnectorData::new()?.try_to_build_config()?, }; let server_host = client_config_data.server_name.unwrap_or(server_host); let stream = tokio_boring::connect( diff --git a/rama-tls/src/rustls/client/connector.rs b/rama-tls/src/rustls/client/connector.rs index 168061cb..65511496 100644 --- a/rama-tls/src/rustls/client/connector.rs +++ b/rama-tls/src/rustls/client/connector.rs @@ -9,7 +9,7 @@ use rama_core::{Context, Layer, Service}; use rama_net::address::Host; use rama_net::client::{ConnectorService, EstablishedClientConnection}; use rama_net::stream::Stream; -use rama_net::tls::client::NegotiatedTlsParameters; +use rama_net::tls::client::{ClientConfig, NegotiatedTlsParameters}; use rama_net::tls::ApplicationProtocol; use rama_net::transport::TryRefIntoTransportContext; use std::fmt; @@ -273,7 +273,21 @@ where "TlsConnector(auto): attempt to secure inner connection", ); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by 
layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into rustls connector data")?, + ), + None => None, + }, + }; + let (stream, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; tracing::trace!( @@ -332,7 +346,21 @@ where let server_host = transport_ctx.authority.host().clone(); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into rustls connector data")?, + ), + None => None, + }, + }; + let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; ctx.insert(negotiated_params); @@ -388,7 +416,21 @@ where } }; - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into rustls connector data")?, + ), + None => None, + }, + }; + let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; ctx.insert(negotiated_params); From aab34c8e1889e9f303cd3855eebc52969c7fa41a Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Wed, 19 Feb 2025 22:11:35 +0100 Subject: [PATCH 18/39] merge http headers with base headers (UA Emulation) --- rama-http-types/src/proto/h1/headers/map.rs | 9 +- rama-http-types/src/proto/h1/headers/mod.rs | 4 +- rama-ua/src/emulate/layer.rs | 30 ++++ rama-ua/src/emulate/service.rs | 159 ++++++++++++++++++-- rama-ua/src/profile/http.rs | 8 +- 5 files changed, 193 insertions(+), 17 deletions(-) diff --git 
a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs index 5945a7b4..1de0deb8 100644 --- a/rama-http-types/src/proto/h1/headers/map.rs +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -232,7 +232,9 @@ impl Iterator for Http1HeaderMapIntoIter { } #[derive(Debug)] -struct HeaderMapValueRemover { +/// Utility that can be used to be able to remove +/// headers from a [`HeaderMap`] in random order, one by one. +pub struct HeaderMapValueRemover { header_map: HeaderMap, removed_values: Option>>, } @@ -247,7 +249,7 @@ impl From for HeaderMapValueRemover { } impl HeaderMapValueRemover { - fn remove(&mut self, header: &HeaderName) -> Option { + pub fn remove(&mut self, header: &HeaderName) -> Option { match self.header_map.entry(header) { header::Entry::Occupied(occupied_entry) => { let (k, mut values) = occupied_entry.remove_entry_mult(); @@ -290,7 +292,8 @@ impl IntoIterator for HeaderMapValueRemover { } #[derive(Debug)] -struct HeaderMapValueRemoverIntoIter { +/// Produced by the [`IntoIterator`] implementation for [`HeaderMapValueRemover`]. 
+pub struct HeaderMapValueRemoverIntoIter { cached_header_name: Option, cached_headers: Option>>, removed_headers: diff --git a/rama-http-types/src/proto/h1/headers/mod.rs b/rama-http-types/src/proto/h1/headers/mod.rs index 9a756909..3d615465 100644 --- a/rama-http-types/src/proto/h1/headers/mod.rs +++ b/rama-http-types/src/proto/h1/headers/mod.rs @@ -12,4 +12,6 @@ pub use name::{Http1HeaderName, IntoHttp1HeaderName, TryIntoHttp1HeaderName}; pub mod original; mod map; -pub use map::{Http1HeaderMap, Http1HeaderMapIntoIter}; +pub use map::{ + HeaderMapValueRemover, HeaderMapValueRemoverIntoIter, Http1HeaderMap, Http1HeaderMapIntoIter, +}; diff --git a/rama-ua/src/emulate/layer.rs b/rama-ua/src/emulate/layer.rs index 14858c87..f1b663d2 100644 --- a/rama-ua/src/emulate/layer.rs +++ b/rama-ua/src/emulate/layer.rs @@ -1,12 +1,14 @@ use std::fmt; use rama_core::Layer; +use rama_http_types::HeaderName; use super::UserAgentSelectFallback; pub struct UserAgentEmulateLayer

{ provider: P, optional: bool, + input_header_order: Option, select_fallback: Option, } @@ -15,6 +17,7 @@ impl fmt::Debug for UserAgentEmulateLayer

{ f.debug_struct("UserAgentEmulateLayer") .field("provider", &self.provider) .field("optional", &self.optional) + .field("input_header_order", &self.input_header_order) .field("select_fallback", &self.select_fallback) .finish() } @@ -25,6 +28,7 @@ impl Clone for UserAgentEmulateLayer

{ Self { provider: self.provider.clone(), optional: self.optional, + input_header_order: self.input_header_order.clone(), select_fallback: self.select_fallback, } } @@ -35,6 +39,7 @@ impl

UserAgentEmulateLayer

{ Self { provider, optional: false, + input_header_order: None, select_fallback: None, } } @@ -53,6 +58,28 @@ impl

UserAgentEmulateLayer

{ self } + /// Define a header that if present is to contain a CSV header name list, + /// that allows you to define the desired header order for the (extra) headers + /// found in the input (http) request. + /// + /// Extra meaning any headers not considered a base header and already defined + /// by the (selected) User Agent Profile. + /// + /// This can be useful because your http client might not respect the header casing + /// and/or order of the headers taken together. Using this metadata allows you to + /// communicate this data through anyway. If however your http client does respect + /// casing and order, or you don't care about some of it, you might not need it. + pub fn input_header_order(mut self, name: HeaderName) -> Self { + self.input_header_order = Some(name); + self + } + + /// See [`Self::input_header_order`]. + pub fn set_input_header_order(&mut self, name: HeaderName) -> &mut Self { + self.input_header_order = Some(name); + self + } + /// Choose what to do in case no profile could be selected /// using the regular pre-conditions as specified by the provider. pub fn select_fallback(mut self, fb: UserAgentSelectFallback) -> Self { @@ -76,6 +103,9 @@ impl Layer for UserAgentEmulateLayer

{ if let Some(fb) = self.select_fallback { svc.set_select_fallback(fb); } + if let Some(name) = self.input_header_order.clone() { + svc.set_input_header_order(name); + } svc } } diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 3b7148aa..20f87071 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -1,18 +1,24 @@ use std::fmt; use rama_core::{ - error::{BoxError, OpaqueError}, + error::{BoxError, ErrorContext, OpaqueError}, Context, Service, }; use rama_http_types::{ conn::Http1ClientContextParams, - header::CONTENT_TYPE, - proto::{h1::Http1HeaderMap, h2::PseudoHeaderOrder}, - HeaderName, Method, Request, Version, + header::{ACCEPT, ACCEPT_LANGUAGE, CONTENT_TYPE, USER_AGENT}, + proto::{ + h1::{ + headers::{original::OriginalHttp1Headers, HeaderMapValueRemover}, + Http1HeaderMap, + }, + h2::PseudoHeaderOrder, + }, + HeaderMap, HeaderName, Method, Request, Version, }; use rama_utils::macros::match_ignore_ascii_case_str; -use crate::{HttpAgent, RequestInitiator, UserAgent, UserAgentProfile}; +use crate::{HttpAgent, RequestInitiator, UserAgent, UserAgentProfile, CUSTOM_HEADER_MARKER}; use super::{UserAgentProvider, UserAgentSelectFallback}; @@ -20,6 +26,7 @@ pub struct UserAgentEmulateService { inner: S, provider: P, optional: bool, + input_header_order: Option, select_fallback: Option, } @@ -29,6 +36,7 @@ impl fmt::Debug for UserAgentEmulateService .field("inner", &self.inner) .field("provider", &self.provider) .field("optional", &self.optional) + .field("input_header_order", &self.input_header_order) .field("select_fallback", &self.select_fallback) .finish() } @@ -40,6 +48,7 @@ impl Clone for UserAgentEmulateService { inner: self.inner.clone(), provider: self.provider.clone(), optional: self.optional, + input_header_order: self.input_header_order.clone(), select_fallback: self.select_fallback, } } @@ -51,6 +60,7 @@ impl UserAgentEmulateService { inner, provider, optional: false, + input_header_order: 
None, select_fallback: None, } } @@ -69,6 +79,28 @@ impl UserAgentEmulateService { self } + /// Define a header that if present is to contain a CSV header name list, + /// that allows you to define the desired header order for the (extra) headers + /// found in the input (http) request. + /// + /// Extra meaning any headers not considered a base header and already defined + /// by the (selected) User Agent Profile. + /// + /// This can be useful because your http client might not respect the header casing + /// and/or order of the headers taken together. Using this metadata allows you to + /// communicate this data through anyway. If however your http client does respect + /// casing and order, or you don't care about some of it, you might not need it. + pub fn input_header_order(mut self, name: HeaderName) -> Self { + self.input_header_order = Some(name); + self + } + + /// See [`Self::input_header_order`]. + pub fn set_input_header_order(&mut self, name: HeaderName) -> &mut Self { + self.input_header_order = Some(name); + self + } + /// Choose what to do in case no profile could be selected /// using the regular pre-conditions as specified by the provider. 
pub fn select_fallback(mut self, fb: UserAgentSelectFallback) -> Self { @@ -96,7 +128,7 @@ where async fn serve( &self, mut ctx: Context, - req: Request, + mut req: Request, ) -> Result { if let Some(fallback) = self.select_fallback { ctx.insert(fallback); @@ -136,11 +168,35 @@ where "user agent emulation: skip http settings as http is instructed to be preserved" ); } else { - emulate_http_settings(&mut ctx, &req, profile); - let _base_http_headers = get_base_http_headers(&ctx, &req, profile); + emulate_http_settings(&mut ctx, &mut req, profile); + let base_http_headers = get_base_http_headers(&ctx, &req, profile); + let original_http_header_order = + get_original_http_header_order(&ctx, &req, self.input_header_order.as_ref()) + .context("collect original http header order")?; + + let original_headers = req.headers().clone(); + + let preserve_ua_header = ctx + .get::() + .map(|ua| ua.preserve_ua_header()) + .unwrap_or_default(); + + let output_headers = merge_http_headers( + base_http_headers, + original_http_header_order, + original_headers, + preserve_ua_header, + )?; - // TODO: merge base headers with incoming headers... allowing some to overwrite, others not... - // also allowing anyway to overwrite if something is set... 
+ tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent emulation: http settings and headers emulated" + ); + let (output_headers, original_headers) = output_headers.into_parts(); + *req.headers_mut() = output_headers; + req.extensions_mut().insert(original_headers); } #[cfg(feature = "tls")] @@ -173,7 +229,7 @@ where fn emulate_http_settings( ctx: &mut Context, - req: &Request, + req: &mut Request, profile: &UserAgentProfile, ) { match req.version() { @@ -195,7 +251,7 @@ fn emulate_http_settings( platform = ?profile.platform, "UA emulation add h2-specific settings", ); - ctx.insert(PseudoHeaderOrder::from_iter( + req.extensions_mut().insert(PseudoHeaderOrder::from_iter( profile.http.h2.http_pseudo_headers.iter(), )); } @@ -297,3 +353,82 @@ fn get_base_http_headers_from_req_init( .unwrap_or(&profile.http.headers.navigate), } } + +fn get_original_http_header_order( + ctx: &Context, + req: &Request, + input_header_order: Option<&HeaderName>, +) -> Result, OpaqueError> { + if let Some(header) = input_header_order.and_then(|name| req.headers().get(name)) { + let s = header.to_str().context("interpret header as a utf-8 str")?; + let mut headers = OriginalHttp1Headers::with_capacity(s.matches(',').count()); + for s in s.split(',') { + let s = s.trim(); + if s.is_empty() { + continue; + } + headers.push(s.parse().context("parse header part as h1 headern name")?); + } + return Ok(Some(headers)); + } + Ok(ctx.get().cloned()) +} + +fn merge_http_headers( + base_http_headers: &Http1HeaderMap, + original_http_header_order: Option, + original_headers: HeaderMap, + preserve_ua_header: bool, +) -> Result { + let mut original_headers = HeaderMapValueRemover::from(original_headers); + + let mut output_headers_a = Vec::new(); + let mut output_headers_b = Vec::new(); + + let mut output_headers_ref = &mut output_headers_a; + + // put all "base" headers in correct order, and with proper name casing + for 
(base_name, base_value) in base_http_headers.clone().into_iter() { + let base_header_name = base_name.header_name(); + match base_header_name { + &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE => { + let value = original_headers + .remove(base_header_name) + .unwrap_or(base_value); + output_headers_ref.push((base_name, value)); + } + &USER_AGENT if preserve_ua_header => { + let value = original_headers + .remove(base_header_name) + .unwrap_or(base_value); + output_headers_ref.push((base_name, value)); + } + _ => { + if base_header_name == CUSTOM_HEADER_MARKER { + output_headers_ref = &mut output_headers_b; + } else { + output_headers_ref.push((base_name, base_value)); + } + } + } + } + + // respect original header order of original headers where possible + for header_name in original_http_header_order.into_iter().flatten() { + if let Some(value) = original_headers.remove(header_name.header_name()) { + output_headers_a.push((header_name, value)); + } + } + + Ok(Http1HeaderMap::from_iter( + output_headers_a + .into_iter() + .chain(original_headers.into_iter()) // add all remaining original headers in any order within the right loc + .chain(output_headers_b.into_iter()), + )) +} + +// TODO: test: +// - get_base_http_headers +// - get_original_http_header_order +// - merge_http_headers diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 9adb82a4..719284e1 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -1,6 +1,12 @@ -use rama_http_types::proto::{h1::Http1HeaderMap, h2::PseudoHeader}; +use rama_http_types::{ + proto::{h1::Http1HeaderMap, h2::PseudoHeader}, + HeaderName, +}; use serde::{Deserialize, Serialize}; +pub static CUSTOM_HEADER_MARKER: HeaderName = + HeaderName::from_static("x-rama-custom-header-marker"); + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct HttpProfile { pub headers: HttpHeadersProfile, From 34b3691e28f798db5bd3dad9ab074381efa2a6aa Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker 
Date: Wed, 19 Feb 2025 22:22:37 +0100 Subject: [PATCH 19/39] minor fixes in merge_http_headers (ua emulation) --- rama-ua/src/emulate/service.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 20f87071..aa698332 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -6,7 +6,7 @@ use rama_core::{ }; use rama_http_types::{ conn::Http1ClientContextParams, - header::{ACCEPT, ACCEPT_LANGUAGE, CONTENT_TYPE, USER_AGENT}, + header::{ACCEPT, ACCEPT_LANGUAGE, CONTENT_TYPE, COOKIE, REFERER, USER_AGENT}, proto::{ h1::{ headers::{original::OriginalHttp1Headers, HeaderMapValueRemover}, @@ -390,18 +390,24 @@ fn merge_http_headers( // put all "base" headers in correct order, and with proper name casing for (base_name, base_value) in base_http_headers.clone().into_iter() { let base_header_name = base_name.header_name(); + let original_value = original_headers.remove(base_header_name); match base_header_name { &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE => { - let value = original_headers - .remove(base_header_name) - .unwrap_or(base_value); + let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); } - &USER_AGENT if preserve_ua_header => { - let value = original_headers - .remove(base_header_name) - .unwrap_or(base_value); - output_headers_ref.push((base_name, value)); + &REFERER | &COOKIE => { + if let Some(value) = original_value { + output_headers_ref.push((base_name, value)); + } + } + &USER_AGENT => { + if preserve_ua_header { + output_headers_ref.push((base_name, base_value)); + } else { + let value = original_value.unwrap_or(base_value); + output_headers_ref.push((base_name, value)); + } } _ => { if base_header_name == CUSTOM_HEADER_MARKER { From 5f92b0e6780d9f28502a0380cb779fa6f1610b48 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Thu, 20 Feb 2025 11:42:06 +0100 Subject: [PATCH 
20/39] add UA db tests ; get + rnd work as expected --- rama-ua/src/emulate/service.rs | 45 +++--- rama-ua/src/profile/db.rs | 250 +++++++++++++++++++++++++++++++++ rama-ua/src/profile/http.rs | 4 +- rama-ua/src/ua/info.rs | 8 ++ 4 files changed, 284 insertions(+), 23 deletions(-) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index aa698332..5221e4bc 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -234,26 +234,29 @@ fn emulate_http_settings( ) { match req.version() { Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => { - tracing::trace!( - ua_kind = %profile.ua_kind, - ua_version = ?profile.ua_version, - platform = ?profile.platform, - "UA emulation add http1-specific settings", - ); - ctx.insert(Http1ClientContextParams { - title_header_case: profile.http.h1.title_case_headers, - }); + if let Some(h1) = &profile.http.h1 { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add http1-specific settings", + ); + ctx.insert(Http1ClientContextParams { + title_header_case: h1.title_case_headers, + }); + } } Version::HTTP_2 => { - tracing::trace!( - ua_kind = %profile.ua_kind, - ua_version = ?profile.ua_version, - platform = ?profile.platform, - "UA emulation add h2-specific settings", - ); - req.extensions_mut().insert(PseudoHeaderOrder::from_iter( - profile.http.h2.http_pseudo_headers.iter(), - )); + if let Some(h2) = &profile.http.h2 { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add h2-specific settings", + ); + req.extensions_mut() + .insert(PseudoHeaderOrder::from_iter(h2.http_pseudo_headers.iter())); + } } Version::HTTP_3 => tracing::debug!( "UA emulation not yet supported for h3: not applying anything h3-specific" @@ -371,7 +374,7 @@ fn get_original_http_header_order( } return Ok(Some(headers)); } - Ok(ctx.get().cloned()) + 
Ok(ctx.get().or_else(|| req.extensions().get()).cloned()) } fn merge_http_headers( @@ -429,8 +432,8 @@ fn merge_http_headers( Ok(Http1HeaderMap::from_iter( output_headers_a .into_iter() - .chain(original_headers.into_iter()) // add all remaining original headers in any order within the right loc - .chain(output_headers_b.into_iter()), + .chain(original_headers) // add all remaining original headers in any order within the right loc + .chain(output_headers_b), )) } diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 7fc03e23..05e084c7 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -16,6 +16,16 @@ pub struct UserAgentDatabase { } impl UserAgentDatabase { + #[inline] + pub fn len(&self) -> usize { + self.profiles.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.profiles.is_empty() + } + pub fn iter_ua_str(&self) -> impl Iterator { self.map_ua_string.keys().map(|s| s.as_str()) } @@ -179,3 +189,243 @@ impl FromIterator for UserAgentDatabase { db } } + +#[cfg(test)] +mod tests { + use rama_http_types::{header::USER_AGENT, proto::h1::Http1HeaderMap, HeaderValue}; + + use super::*; + + #[test] + fn test_ua_db_empty() { + let db = UserAgentDatabase::default(); + assert_eq!(db.iter().count(), 0); + assert!(db.get(&UserAgent::new("")).is_none()); + assert!(db.get(&UserAgent::new("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")).is_none()); + + let rnd = db.rnd(); + assert!(rnd.is_none()); + + assert!(db.iter_ua_str().next().is_none()); + assert!(db.iter_ua_kind().next().is_none()); + assert!(db.iter_platform().next().is_none()); + assert!(db.iter_device().next().is_none()); + } + + #[test] + fn test_ua_db_get_by_ua_str() { + let db = get_dummy_ua_db(); + + let profile = db.get(&UserAgent::new("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")).unwrap(); + 
assert_eq!(profile.ua_kind, UserAgentKind::Chromium); + assert_eq!(profile.ua_version, Some(120)); + assert_eq!(profile.platform, Some(PlatformKind::Windows)); + assert_eq!(profile.http.headers.navigate.get(USER_AGENT).unwrap().to_str().unwrap(), "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0"); + } + + #[test] + fn test_ua_db_get_by_ua_kind_and_device() { + let db = get_dummy_ua_db(); + let test_cases = [ + ( + "Chrome Desktop", + UserAgentKind::Chromium, + DeviceKind::Desktop, + ), + ("Chrome Mobile", UserAgentKind::Chromium, DeviceKind::Mobile), + ( + "Desktop Firefox", + UserAgentKind::Firefox, + DeviceKind::Desktop, + ), + ( + "Mobile with Firefox", + UserAgentKind::Firefox, + DeviceKind::Mobile, + ), + ( + "Safari on Desktop", + UserAgentKind::Safari, + DeviceKind::Desktop, + ), + ("mobile&safari", UserAgentKind::Safari, DeviceKind::Mobile), + ]; + + for (ua_str, ua_kind, device) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!(profile.ua_kind, ua_kind); + assert!(profile + .platform + .map(|p| p.device() == device) + .unwrap_or_default()); + } + } + + #[test] + fn test_ua_db_get_by_ua_kind_and_platform() { + let db = get_dummy_ua_db(); + let test_cases = [ + ( + "Chrome Windows", + UserAgentKind::Chromium, + PlatformKind::Windows, + ), + ("MacOS Chrome", UserAgentKind::Chromium, PlatformKind::MacOS), + ( + "Chrome&Windows", + UserAgentKind::Chromium, + PlatformKind::Windows, + ), + ( + "Firefox on Windows", + UserAgentKind::Firefox, + PlatformKind::Windows, + ), + ( + "MacOS with Firefox", + UserAgentKind::Firefox, + PlatformKind::MacOS, + ), + ( + "Firefox + Linux", + UserAgentKind::Firefox, + PlatformKind::Linux, + ), + ]; + + for (ua_str, ua_kind, platform) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!(profile.ua_kind, ua_kind); + assert_eq!(profile.platform, Some(platform)); + } + } + + 
#[test] + fn test_ua_db_get_by_ua_kind() { + let db = get_dummy_ua_db(); + let test_cases = [ + ("Firefox", UserAgentKind::Firefox), + ("Safari", UserAgentKind::Safari), + ("Chrome", UserAgentKind::Chromium), + ("Chromium", UserAgentKind::Chromium), + ]; + + for (ua_str, ua_kind) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!(profile.ua_kind, ua_kind, "ua_str: {}", ua_str); + } + } + + #[test] + fn test_ua_db_get_by_device() { + let db = get_dummy_ua_db(); + let test_cases = [ + ("Desktop", DeviceKind::Desktop), + ("DESKTOP", DeviceKind::Desktop), + ("desktop", DeviceKind::Desktop), + ("Mobile", DeviceKind::Mobile), + ("MOBILE", DeviceKind::Mobile), + ("mobile", DeviceKind::Mobile), + ]; + + for (ua_str, device) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!( + profile.platform.map(|p| p.device() == device), + Some(true), + "ua_str: {}", + ua_str + ); + } + } + + #[test] + fn test_ua_db_rnd() { + let db = get_dummy_ua_db(); + + let mut set = std::collections::HashSet::new(); + for _ in 0..db.len() * 100 { + let rnd = db.rnd().unwrap(); + set.insert( + rnd.http + .headers + .navigate + .get(USER_AGENT) + .expect("ua header") + .to_str() + .expect("utf-8 ua header value") + .to_owned(), + ); + } + + assert_eq!(set.len(), db.len()); + } + + fn dummy_ua_profile_from_str(s: &str) -> UserAgentProfile { + let ua = UserAgent::new(s); + UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: crate::HttpProfile { + headers: crate::HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(USER_AGENT, HeaderValue::from_str(s).unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + h1: None, + h2: None, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + } + } + + fn 
get_dummy_ua_db() -> UserAgentDatabase { + let mut db = UserAgentDatabase::default(); + + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 14_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Linux; Android 10; HD1913) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36 EdgA/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPhone; CPU iPhone OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 EdgiOS/120.0.0.0 Mobile/15E148 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Windows NT 11.0; Win64; x64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 14.1; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (X11; Linux x86_64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Android 14; 
Mobile; rv:120.0) Gecko/120.0 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPhone; CPU iPhone OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) FxiOS/120.0 Mobile/15E148 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 14_1) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPhone; CPU iPhone OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPad; CPU OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15")); + + db + } +} diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 719284e1..ee8ee0f9 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -10,8 +10,8 @@ pub static CUSTOM_HEADER_MARKER: HeaderName = #[derive(Debug, Clone, Deserialize, Serialize)] pub struct HttpProfile { pub headers: HttpHeadersProfile, - pub h1: Http1Profile, - pub h2: Http2Profile, + pub h1: Option, + pub h2: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index 1295a9f5..957f7f2d 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -147,6 +147,14 @@ impl UserAgent { } } + /// returns the version of the [`UserAgent`], if known. + pub fn ua_version(&self) -> Option { + match &self.data { + UserAgentData::Standard { info, .. 
} => info.version, + _ => None, + } + } + /// returns the [`PlatformKind`] used by the [`UserAgent`], if known. /// /// This is the platform the [`UserAgent`] is running on. From ace252aeadcea6e133471a54fa5f108229e5fde0 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Thu, 20 Feb 2025 11:46:15 +0100 Subject: [PATCH 21/39] prepare fp js script for rama custom header marker --- rama-cli/assets/script.js | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/rama-cli/assets/script.js b/rama-cli/assets/script.js index eb02762b..1bcc1ef2 100644 --- a/rama-cli/assets/script.js +++ b/rama-cli/assets/script.js @@ -25,8 +25,7 @@ async function fetchWithBackoff(url, options) { // Function to make a GET request async function makeGetRequest(url) { const headers = { - 'x-CusToM-HEADER': `rama-fp${Date.now()}`, - 'x-CusToM-HEADER-eXtRa': `rama-fpeXtRa-${Date.now()}`, + 'x-RAMA-custom-header-marker': `rama-fp${Date.now()}`, }; const options = { @@ -40,8 +39,7 @@ async function makeGetRequest(url) { // Function to make a POST request async function makePostRequest(url, number) { const headers = { - 'x-CusToM-HEADER': `rama-fp${Date.now()}`, - 'x-CusToM-HEADER-eXtRa': `rama-fpeXtRa-${Date.now()}`, + 'x-RAMA-custom-header-marker': `rama-fp${Date.now()}`, }; const body = JSON.stringify({ number }); @@ -60,8 +58,7 @@ function makeRequestWithXHR(url, method, number) { return new Promise((resolve, reject) => { const xhr = new XMLHttpRequest(); xhr.open(method, url); - xhr.setRequestHeader('x-CusToM-HEADER', `rama-fp${Date.now()}`); - xhr.setRequestHeader('x-CusToM-HEADER-eXtRa', `rama-fpeXtRa-${Date.now()}`); + xhr.setRequestHeader('x-RAMA-custom-header-marker', `rama-fp${Date.now()}`); xhr.onload = function () { if (xhr.status >= 200 && xhr.status < 300) { From 2e713f126006506f4f139262709a8c3df0f2e27a Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Thu, 20 Feb 2025 11:58:56 +0100 Subject: [PATCH 22/39] support auto detecting user agent as part 
of emulation layer can be useful so one doesn't require two layers if there is no use of the user agent outside of emulating --- rama-ua/src/emulate/layer.rs | 26 ++++++++++++++++++- rama-ua/src/emulate/service.rs | 46 ++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/rama-ua/src/emulate/layer.rs b/rama-ua/src/emulate/layer.rs index f1b663d2..baa98289 100644 --- a/rama-ua/src/emulate/layer.rs +++ b/rama-ua/src/emulate/layer.rs @@ -8,6 +8,7 @@ use super::UserAgentSelectFallback; pub struct UserAgentEmulateLayer

{ provider: P, optional: bool, + try_auto_detect_user_agent: bool, input_header_order: Option, select_fallback: Option, } @@ -17,6 +18,10 @@ impl fmt::Debug for UserAgentEmulateLayer

{ f.debug_struct("UserAgentEmulateLayer") .field("provider", &self.provider) .field("optional", &self.optional) + .field( + "try_auto_detect_user_agent", + &self.try_auto_detect_user_agent, + ) .field("input_header_order", &self.input_header_order) .field("select_fallback", &self.select_fallback) .finish() @@ -28,6 +33,7 @@ impl Clone for UserAgentEmulateLayer

{ Self { provider: self.provider.clone(), optional: self.optional, + try_auto_detect_user_agent: self.try_auto_detect_user_agent, input_header_order: self.input_header_order.clone(), select_fallback: self.select_fallback, } @@ -39,6 +45,7 @@ impl

UserAgentEmulateLayer

{ Self { provider, optional: false, + try_auto_detect_user_agent: false, input_header_order: None, select_fallback: None, } @@ -58,6 +65,22 @@ impl

UserAgentEmulateLayer

{ self } + /// If true, the layer will try to auto-detect the user agent from the request, + /// but only in case that info is not yet found in the context. + pub fn try_auto_detect_user_agent(mut self, try_auto_detect_user_agent: bool) -> Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + + /// See [`Self::try_auto_detect_user_agent`]. + pub fn set_try_auto_detect_user_agent( + &mut self, + try_auto_detect_user_agent: bool, + ) -> &mut Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + /// Define a header that if present is to contain a CSV header name list, /// that allows you to define the desired header order for the (extra) headers /// found in the input (http) request. @@ -99,7 +122,8 @@ impl Layer for UserAgentEmulateLayer

{ fn layer(&self, inner: S) -> Self::Service { let mut svc = super::UserAgentEmulateService::new(inner, self.provider.clone()) - .optional(self.optional); + .optional(self.optional) + .try_auto_detect_user_agent(self.try_auto_detect_user_agent); if let Some(fb) = self.select_fallback { svc.set_select_fallback(fb); } diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 5221e4bc..265de02e 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -26,6 +26,7 @@ pub struct UserAgentEmulateService { inner: S, provider: P, optional: bool, + try_auto_detect_user_agent: bool, input_header_order: Option, select_fallback: Option, } @@ -36,6 +37,10 @@ impl fmt::Debug for UserAgentEmulateService .field("inner", &self.inner) .field("provider", &self.provider) .field("optional", &self.optional) + .field( + "try_auto_detect_user_agent", + &self.try_auto_detect_user_agent, + ) .field("input_header_order", &self.input_header_order) .field("select_fallback", &self.select_fallback) .finish() @@ -48,6 +53,7 @@ impl Clone for UserAgentEmulateService { inner: self.inner.clone(), provider: self.provider.clone(), optional: self.optional, + try_auto_detect_user_agent: self.try_auto_detect_user_agent, input_header_order: self.input_header_order.clone(), select_fallback: self.select_fallback, } @@ -60,6 +66,7 @@ impl UserAgentEmulateService { inner, provider, optional: false, + try_auto_detect_user_agent: false, input_header_order: None, select_fallback: None, } @@ -79,6 +86,22 @@ impl UserAgentEmulateService { self } + /// If true, the service will try to auto-detect the user agent from the request, + /// but only in case that info is not yet found in the context. + pub fn try_auto_detect_user_agent(mut self, try_auto_detect_user_agent: bool) -> Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + + /// See [`Self::try_auto_detect_user_agent`]. 
+ pub fn set_try_auto_detect_user_agent( + &mut self, + try_auto_detect_user_agent: bool, + ) -> &mut Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + /// Define a header that if present is to contain a CSV header name list, /// that allows you to define the desired header order for the (extra) headers /// found in the input (http) request. @@ -134,6 +157,29 @@ where ctx.insert(fallback); } + if self.try_auto_detect_user_agent && !ctx.contains::() { + match req + .headers() + .get(USER_AGENT) + .and_then(|ua| ua.to_str().ok()) + { + Some(ua_str) => { + let user_agent = UserAgent::new(ua_str); + tracing::trace!( + ua_str = %ua_str, + %user_agent, + "user agent auto-detected from request" + ); + ctx.insert(user_agent); + } + None => { + tracing::debug!( + "user agent auto-detection not possible: no user agent header present" + ); + } + } + } + let profile = match self.provider.select_user_agent_profile(&ctx) { Some(profile) => profile, None => { From c35f2bf90470d4a30f8eaf196457fa9d05386595 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Fri, 21 Feb 2025 09:28:32 +0100 Subject: [PATCH 23/39] add authorization header to an opt-in only base header --- rama-ua/src/emulate/service.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 73178d0e..da3de3b5 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -7,7 +7,7 @@ use rama_core::{ use rama_http_types::{ HeaderMap, HeaderName, Method, Request, Version, conn::Http1ClientContextParams, - header::{ACCEPT, ACCEPT_LANGUAGE, CONTENT_TYPE, COOKIE, REFERER, USER_AGENT}, + header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_TYPE, COOKIE, REFERER, USER_AGENT}, proto::{ h1::{ Http1HeaderMap, @@ -445,7 +445,7 @@ fn merge_http_headers( let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); } - &REFERER | &COOKIE => { + 
&REFERER | &COOKIE | &AUTHORIZATION => { if let Some(value) = original_value { output_headers_ref.push((base_name, value)); } From c734b33e33e23e9622b4fdac17af2bc59b4c7006 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Fri, 21 Feb 2025 10:41:44 +0100 Subject: [PATCH 24/39] add more tests for sub routines of rama-ua emulate service so far all seems to be without any bugs found so far, still good to have these tests to prevent mistakes in future --- rama-http/src/layer/ua.rs | 8 +- rama-ua/src/emulate/service.rs | 626 +++++++++++++++++++++++++++++++-- rama-ua/src/ua/info.rs | 35 +- rama-ua/src/ua/mod.rs | 1 + rama-ua/src/ua/parse.rs | 2 +- rama-ua/src/ua/parse_tests.rs | 40 +-- 6 files changed, 646 insertions(+), 66 deletions(-) diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index 1962100d..4249486e 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -143,16 +143,16 @@ where if let Some(mut ua) = user_agent.take() { if let Some(overwrites) = overwrites { if let Some(http_agent) = overwrites.http { - ua.with_http_agent(http_agent); + ua.set_http_agent(http_agent); } if let Some(tls_agent) = overwrites.tls { - ua.with_tls_agent(tls_agent); + ua.set_tls_agent(tls_agent); } if let Some(preserve_ua) = overwrites.preserve_ua { - ua.with_preserve_ua_header(preserve_ua); + ua.set_preserve_ua_header(preserve_ua); } if let Some(req_init) = overwrites.req_init { - ua.with_request_initiator(req_init); + ua.set_request_initiator(req_init); } } diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index da3de3b5..05679ed9 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -16,9 +16,11 @@ use rama_http_types::{ h2::PseudoHeaderOrder, }, }; -use rama_utils::macros::match_ignore_ascii_case_str; -use crate::{CUSTOM_HEADER_MARKER, HttpAgent, RequestInitiator, UserAgent, UserAgentProfile}; +use crate::{ + CUSTOM_HEADER_MARKER, HttpAgent, RequestInitiator, UserAgent, 
UserAgentProfile, + contains_ignore_ascii_case, +}; use super::{UserAgentProvider, UserAgentSelectFallback}; @@ -319,11 +321,7 @@ fn get_base_http_headers<'a, Body, State>( req: &Request, profile: &'a UserAgentProfile, ) -> &'a Http1HeaderMap { - match ctx - .get::() - .copied() - .or_else(|| ctx.get::().and_then(|ua| ua.request_initiator())) - { + match ctx.get::().and_then(|ua| ua.request_initiator()) { Some(req_init) => { tracing::trace!(%req_init, "base http headers defined based on hint from UserAgent (overwrite)"); get_base_http_headers_from_req_init(req_init, profile) @@ -333,40 +331,43 @@ fn get_base_http_headers<'a, Body, State>( // and that they are cheap enough to check. None => match *req.method() { Method::GET => { - tracing::trace!("base http headers defined based on Get=Navigate assumption"); - &profile.http.headers.navigate + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else { + RequestInitiator::Navigate + }; + tracing::trace!(%req_init, "base http headers defined based on Get=NavigateOrXhr assumption"); + get_base_http_headers_from_req_init(req_init, profile) } Method::POST => { - let req_init = req - .headers() - .get(CONTENT_TYPE) - .and_then(|ct| ct.to_str().ok()) - .and_then(|s| { - match_ignore_ascii_case_str! 
{ - match (s) { - "form-" => Some(RequestInitiator::Form), - _ => None, - } - } - }) - .unwrap_or(RequestInitiator::Fetch); + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else if headers_contains_partial_value(req.headers(), &CONTENT_TYPE, "form-") { + RequestInitiator::Form + } else { + RequestInitiator::Fetch + }; tracing::trace!(%req_init, "base http headers defined based on Post=FormOrFetch assumption"); get_base_http_headers_from_req_init(req_init, profile) } _ => { - let req_init = req - .headers() - .get(HeaderName::from_static("x-requested-with")) - .and_then(|ct| ct.to_str().ok()) - .and_then(|s| { - match_ignore_ascii_case_str! { - match (s) { - "XmlHttpRequest" => Some(RequestInitiator::Xhr), - _ => None, - } - } - }) - .unwrap_or(RequestInitiator::Fetch); + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else { + RequestInitiator::Fetch + }; tracing::trace!(%req_init, "base http headers defined based on XhrOrFetch assumption"); get_base_http_headers_from_req_init(req_init, profile) } @@ -374,6 +375,16 @@ fn get_base_http_headers<'a, Body, State>( } } +static X_REQUESTED_WITH: HeaderName = HeaderName::from_static("x-requested-with"); + +fn headers_contains_partial_value(headers: &HeaderMap, name: &HeaderName, value: &str) -> bool { + headers + .get(name) + .and_then(|value| value.to_str().ok()) + .map(|s| contains_ignore_ascii_case(s, value).is_some()) + .unwrap_or_default() +} + fn get_base_http_headers_from_req_init( req_init: RequestInitiator, profile: &UserAgentProfile, @@ -484,6 +495,547 @@ fn merge_http_headers( } // TODO: test: -// - get_base_http_headers -// - get_original_http_header_order // - merge_http_headers + +#[cfg(test)] +mod tests { + use super::*; + + use std::{convert::Infallible, str::FromStr}; + + use itertools::Itertools as _; + use 
rama_core::service::service_fn; + use rama_http_types::{Body, HeaderValue, header::ETAG, proto::h1::Http1HeaderName}; + + use crate::{HttpHeadersProfile, HttpProfile}; + + #[test] + fn test_get_original_http_header_order() { + struct TestCase { + description: &'static str, + req: Request, + ctx: Option>, + expected_input_header_order: &'static str, + } + + let test_cases = [ + TestCase { + description: "empty request", + req: Request::new(Body::empty()), + ctx: None, + expected_input_header_order: "", + }, + TestCase { + description: "request original header order in req extensions", + req: { + let mut req = Request::new(Body::empty()); + req.extensions_mut() + .insert(OriginalHttp1Headers::from_iter([ + Http1HeaderName::from_str("x-REQUESTED-with").unwrap(), + Http1HeaderName::from_str("Accept").unwrap(), + ])); + req + }, + ctx: None, + expected_input_header_order: "x-REQUESTED-with,Accept", + }, + TestCase { + description: "request original header order in ctx", + req: { + let mut req = Request::new(Body::empty()); + req.headers_mut().insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("BAR"), + ); + req + }, + // ctx has precedence over req extensions + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert(OriginalHttp1Headers::from_iter([ + Http1HeaderName::from_str("x-REQUESTED-with").unwrap(), + Http1HeaderName::from_str("Accept").unwrap(), + ])); + ctx + }), + expected_input_header_order: "x-REQUESTED-with,Accept", + }, + TestCase { + description: "request with headers but no original header order", + req: { + let mut req = Request::new(Body::empty()); + req.headers_mut().insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("BAR"), + ); + req + }, + ctx: None, + expected_input_header_order: "", + }, + ]; + + for test_case in test_cases { + let input_header_order = get_original_http_header_order( + test_case.ctx.as_ref().unwrap_or(&Context::default()), + &test_case.req, + None, + ) + .unwrap(); + let 
input_header_order_str = input_header_order + .map(|headers| { + headers + .into_iter() + .map(|header| header.to_string()) + .join(",") + }) + .unwrap_or_default(); + assert_eq!( + input_header_order_str, test_case.expected_input_header_order, + "{}", + test_case.description, + ); + } + } + + #[tokio::test] + async fn test_get_base_http_headers_profile_with_only_navigate_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + xhr: None, + fetch: None, + form: None, + }, + h1: None, + h2: None, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::DELETE) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + assert_eq!(res, "navigate"); + } + + #[tokio::test] + async fn test_get_base_http_headers_profile_without_fetch_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, 
HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + xhr: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("xhr").unwrap())] + .into_iter() + .collect(), + None, + )), + fetch: None, + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + h1: None, + h2: None, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::DELETE) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + assert_eq!(res, "xhr"); + } + + #[tokio::test] + async fn test_get_base_http_headers_profile_without_xhr_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("fetch").unwrap())] + .into_iter() + .collect(), + None, + )), + xhr: None, + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + h1: None, + h2: None, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = 
UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::DELETE) + .header( + HeaderName::from_static("x-requested-with"), + "XmlHttpRequest", + ) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + assert_eq!(res, "fetch"); + } + + #[tokio::test] + async fn test_get_base_http_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("fetch").unwrap())] + .into_iter() + .collect(), + None, + )), + xhr: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("xhr").unwrap())] + .into_iter() + .collect(), + None, + )), + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + h1: None, + h2: None, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + struct TestCase { + description: &'static str, + method: Option, + headers: Option, + ctx: Option>, + expected: &'static str, + } + + let test_cases = [ + TestCase { + description: "GET request", + 
method: None, + headers: None, + ctx: None, + expected: "navigate", + }, + TestCase { + description: "GET request with XRW header", + method: None, + headers: Some( + [( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "GET request with RequestInitiator hint Navigate", + method: None, + headers: None, + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert( + UserAgent::new("").with_request_initiator(RequestInitiator::Navigate), + ); + ctx + }), + expected: "navigate", + }, + TestCase { + description: "GET request with RequestInitiator hint Form", + method: None, + headers: None, + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert(UserAgent::new("").with_request_initiator(RequestInitiator::Form)); + ctx + }), + expected: "form", + }, + TestCase { + description: "explicit GET request", + method: Some(Method::GET), + headers: None, + ctx: None, + expected: "navigate", + }, + TestCase { + description: "explicit POST request", + method: Some(Method::POST), + headers: None, + ctx: None, + expected: "fetch", + }, + TestCase { + description: "explicit POST request with XRW header", + method: Some(Method::POST), + headers: Some( + [( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit POST request with multipart/form-data and XRW header", + method: Some(Method::POST), + headers: Some( + [ + ( + CONTENT_TYPE, + HeaderValue::from_static( + "multipart/form-data; boundary=ExampleBoundaryString", + ), + ), + ( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + ), + ] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit POST request with 
application/x-www-form-urlencoded and XRW header", + method: Some(Method::POST), + headers: Some( + [ + ( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + ), + ( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + ), + ] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit POST request with multipart/form-data", + method: Some(Method::POST), + headers: Some( + [( + CONTENT_TYPE, + HeaderValue::from_static( + "multipart/form-data; boundary=ExampleBoundaryString", + ), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "form", + }, + TestCase { + description: "explicit POST request with application/x-www-form-urlencoded", + method: Some(Method::POST), + headers: Some( + [( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "form", + }, + TestCase { + description: "explicit DELETE request with XRW header", + method: Some(Method::DELETE), + headers: Some( + [( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit DELETE request", + method: Some(Method::DELETE), + headers: None, + ctx: None, + expected: "fetch", + }, + TestCase { + description: "explicit DELETE request with RequestInitiator hint", + method: Some(Method::DELETE), + headers: None, + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert(UserAgent::new("").with_request_initiator(RequestInitiator::Xhr)); + ctx + }), + expected: "xhr", + }, + ]; + + for test_case in test_cases { + let mut req = Request::builder() + .method(test_case.method.unwrap_or(Method::GET)) + .body(Body::empty()) + .unwrap(); + if let Some(headers) = test_case.headers { + req.headers_mut().extend(headers); + } + let ctx = 
test_case.ctx.unwrap_or_default(); + let res = ua_service.serve(ctx, req).await.unwrap(); + assert_eq!(res, test_case.expected, "{}", test_case.description); + } + } +} diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index 260bb283..c83b83c7 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -66,13 +66,25 @@ impl UserAgent { } /// Overwrite the [`HttpAgent`] advertised by the [`UserAgent`]. - pub fn with_http_agent(&mut self, http_agent: HttpAgent) -> &mut Self { + pub fn with_http_agent(mut self, http_agent: HttpAgent) -> Self { self.http_agent_overwrite = Some(http_agent); self } + /// Overwrite the [`HttpAgent`] advertised by the [`UserAgent`]. + pub fn set_http_agent(&mut self, http_agent: HttpAgent) -> &mut Self { + self.http_agent_overwrite = Some(http_agent); + self + } + + /// Overwrite the [`TlsAgent`] advertised by the [`UserAgent`]. + pub fn with_tls_agent(mut self, tls_agent: TlsAgent) -> Self { + self.tls_agent_overwrite = Some(tls_agent); + self + } + /// Overwrite the [`TlsAgent`] advertised by the [`UserAgent`]. - pub fn with_tls_agent(&mut self, tls_agent: TlsAgent) -> &mut Self { + pub fn set_tls_agent(&mut self, tls_agent: TlsAgent) -> &mut Self { self.tls_agent_overwrite = Some(tls_agent); self } @@ -81,7 +93,16 @@ impl UserAgent { /// /// This is used to indicate to emulators that they should respect the User-Agent header /// attached to this [`UserAgent`], if possible. - pub fn with_preserve_ua_header(&mut self, preserve: bool) -> &mut Self { + pub fn with_preserve_ua_header(mut self, preserve: bool) -> Self { + self.preserve_ua_header = preserve; + self + } + + /// Preserve the incoming `User-Agent` (header) value. + /// + /// This is used to indicate to emulators that they should respect the User-Agent header + /// attached to this [`UserAgent`], if possible. 
+ pub fn set_preserve_ua_header(&mut self, preserve: bool) -> &mut Self { self.preserve_ua_header = preserve; self } @@ -93,7 +114,13 @@ impl UserAgent { } /// Define the [`RequestInitiator`] hint. - pub fn with_request_initiator(&mut self, req_init: RequestInitiator) -> &mut Self { + pub fn with_request_initiator(mut self, req_init: RequestInitiator) -> Self { + self.request_initiator = Some(req_init); + self + } + + /// Define the [`RequestInitiator`] hint. + pub fn set_request_initiator(&mut self, req_init: RequestInitiator) -> &mut Self { self.request_initiator = Some(req_init); self } diff --git a/rama-ua/src/ua/mod.rs b/rama-ua/src/ua/mod.rs index 73201c3a..259f5dd9 100644 --- a/rama-ua/src/ua/mod.rs +++ b/rama-ua/src/ua/mod.rs @@ -9,6 +9,7 @@ pub use info::{ }; mod parse; +pub(crate) use parse::contains_ignore_ascii_case; use parse::parse_http_user_agent_header; /// Information that can be used to overwrite the [`UserAgent`] of an http request. diff --git a/rama-ua/src/ua/parse.rs b/rama-ua/src/ua/parse.rs index bb18cafd..13e57e5d 100644 --- a/rama-ua/src/ua/parse.rs +++ b/rama-ua/src/ua/parse.rs @@ -186,7 +186,7 @@ fn parse_ua_version_safari(ua: &str) -> Option { }) } -fn contains_ignore_ascii_case(s: &str, sub: &str) -> Option { +pub(crate) fn contains_ignore_ascii_case(s: &str, sub: &str) -> Option { let n = sub.len(); if n > s.len() { return None; diff --git a/rama-ua/src/ua/parse_tests.rs b/rama-ua/src/ua/parse_tests.rs index ba814fe8..ed473c0a 100644 --- a/rama-ua/src/ua/parse_tests.rs +++ b/rama-ua/src/ua/parse_tests.rs @@ -17,11 +17,11 @@ fn test_parse_desktop_ua() { assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), 
Some(HttpAgent::Firefox)); } @@ -41,11 +41,11 @@ fn test_parse_too_long_ua() { assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -65,11 +65,11 @@ fn test_parse_windows() { assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -95,11 +95,11 @@ fn test_parse_chrome() { assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -125,11 +125,11 @@ fn test_parse_windows_chrome() { assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -152,7 +152,7 @@ fn test_parse_desktop_chrome() { } #[test] -fn test_parse_desktop_chrome_with_version() { +fn 
test_parse_desktop_chrome_set_version() { let ua_str = "desktop chrome/124"; let ua = UserAgent::new(ua_str); @@ -169,7 +169,7 @@ fn test_parse_desktop_chrome_with_version() { } #[test] -fn test_parse_windows_chrome_with_version() { +fn test_parse_windows_chrome_set_version() { let ua_str = "windows chrome/124"; let mut ua = UserAgent::new(ua_str); @@ -189,11 +189,11 @@ fn test_parse_windows_chrome_with_version() { assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -213,11 +213,11 @@ fn test_parse_mobile_ua() { assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -241,11 +241,11 @@ fn test_parse_happy_path_unknown_ua() { assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } @@ -271,11 +271,11 @@ fn test_parse_happy_path_ua_macos_chrome() { assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); assert_eq!(ua.http_agent(), 
Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); + ua.set_tls_agent(TlsAgent::Nss); assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } From 5625f88573b1f376569b8b2b36c63b2dffed6323 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Fri, 21 Feb 2025 14:27:10 +0100 Subject: [PATCH 25/39] fix test with random profiles --- rama-ua/src/profile/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 379a1df6..1e1e95c1 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -356,7 +356,7 @@ mod tests { let db = get_dummy_ua_db(); let mut set = std::collections::HashSet::new(); - for _ in 0..db.len() * 100 { + for _ in 0..db.len() * 1000 { let rnd = db.rnd().unwrap(); set.insert( rnd.http From 59693dc28ab65d714b1f17e83ee35770fa44698a Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sat, 22 Feb 2025 10:31:29 +0100 Subject: [PATCH 26/39] add final unit test for rama-ua emulate service fixed bug regarding preserve UA opt-in flag, thx unit test --- rama-ua/src/emulate/service.rs | 344 ++++++++++++++++++++++++++++++++- 1 file changed, 335 insertions(+), 9 deletions(-) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 05679ed9..c7c7d160 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -234,7 +234,7 @@ where original_http_header_order, original_headers, preserve_ua_header, - )?; + ); tracing::trace!( ua_kind = %profile.ua_kind, @@ -439,7 +439,7 @@ fn merge_http_headers( original_http_header_order: Option, original_headers: HeaderMap, preserve_ua_header: bool, -) -> Result { +) -> Http1HeaderMap { let mut original_headers = HeaderMapValueRemover::from(original_headers); let mut output_headers_a = Vec::new(); @@ -463,10 +463,10 @@ fn merge_http_headers( } &USER_AGENT => { if preserve_ua_header { - 
output_headers_ref.push((base_name, base_value)); - } else { let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); + } else { + output_headers_ref.push((base_name, base_value)); } } _ => { @@ -486,17 +486,14 @@ fn merge_http_headers( } } - Ok(Http1HeaderMap::from_iter( + Http1HeaderMap::from_iter( output_headers_a .into_iter() .chain(original_headers) // add all remaining original headers in any order within the right loc .chain(output_headers_b), - )) + ) } -// TODO: test: -// - merge_http_headers - #[cfg(test)] mod tests { use super::*; @@ -509,6 +506,335 @@ mod tests { use crate::{HttpHeadersProfile, HttpProfile}; + #[test] + fn test_merge_http_headers() { + struct TestCase { + description: &'static str, + base_http_headers: Vec<(&'static str, &'static str)>, + original_http_header_order: Option>, + original_headers: Vec<(&'static str, &'static str)>, + preserve_ua_header: bool, + expected: Vec<(&'static str, &'static str)>, + } + + let test_cases = [ + TestCase { + description: "empty", + base_http_headers: vec![], + original_http_header_order: None, + original_headers: vec![], + preserve_ua_header: false, + expected: vec![], + }, + TestCase { + description: "base headers only", + base_http_headers: vec![ + ("Accept", "text/html"), + ("Content-Type", "application/json"), + ], + original_http_header_order: None, + original_headers: vec![], + preserve_ua_header: false, + expected: vec![ + ("Accept", "text/html"), + ("Content-Type", "application/json"), + ], + }, + TestCase { + description: "original headers only", + base_http_headers: vec![], + original_http_header_order: None, + original_headers: vec![("accept", "text/html")], + preserve_ua_header: false, + expected: vec![("accept", "text/html")], + }, + TestCase { + description: "original and base headers, no conflicts", + base_http_headers: vec![("accept", "text/html"), ("user-agent", "python/3.10")], + original_http_header_order: None, + original_headers: 
vec![("content-type", "application/json")], + preserve_ua_header: false, + expected: vec![ + ("accept", "text/html"), + ("user-agent", "python/3.10"), + ("content-type", "application/json"), + ], + }, + TestCase { + description: "original and base headers, with conflicts", + base_http_headers: vec![ + ("accept", "text/html"), + ("content-type", "text/html"), + ("user-agent", "python/3.10"), + ], + original_http_header_order: Some(vec!["content-type", "user-agent"]), + original_headers: vec![ + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + preserve_ua_header: false, + expected: vec![ + ("accept", "text/html"), + ("content-type", "application/json"), + ("user-agent", "python/3.10"), + ], + }, + TestCase { + description: "original and base headers, with conflicts, preserve ua header", + base_http_headers: vec![ + ("accept", "text/html"), + ("content-type", "text/html"), + ("user-agent", "python/3.10"), + ], + original_http_header_order: Some(vec!["content-type", "user-agent"]), + original_headers: vec![ + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + preserve_ua_header: true, + expected: vec![ + ("accept", "text/html"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "no opt-in base headers defined", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec!["content-type", "user-agent"]), + original_headers: vec![ + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + preserve_ua_header: false, + expected: vec![ + ("accept", "text/html"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "some opt-in base headers defined", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + 
("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec![ + "content-type", + "cookie", + "user-agent", + "referer", + ]), + original_headers: vec![ + ("content-type", "application/json"), + ("cookie", "foo=bar"), + ("user-agent", "php/8.0"), + ("referer", "https://ramaproxy.org"), + ], + preserve_ua_header: false, + expected: vec![ + ("accept", "text/html"), + ("cookie", "foo=bar"), + ("referer", "https://ramaproxy.org"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "all opt-in base headers defined", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec![ + "content-type", + "cookie", + "user-agent", + "referer", + "authorization", + ]), + original_headers: vec![ + ("content-type", "application/json"), + ("cookie", "foo=bar"), + ("user-agent", "php/8.0"), + ("referer", "https://ramaproxy.org"), + ("authorization", "Bearer 42"), + ], + preserve_ua_header: false, + expected: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 42"), + ("cookie", "foo=bar"), + ("referer", "https://ramaproxy.org"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "all opt-in base headers defined, with custom header marker", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("x-rama-custom-header-marker", "1"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec![ + "content-type", + "cookie", + "user-agent", + "referer", + "authorization", + ]), + original_headers: vec![ + ("content-type", "application/json"), + ("cookie", "foo=bar"), + ("user-agent", "php/8.0"), + ("referer", "https://ramaproxy.org"), + ("authorization", "Bearer 
42"), + ], + preserve_ua_header: false, + expected: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 42"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ("cookie", "foo=bar"), + ("referer", "https://ramaproxy.org"), + ], + }, + TestCase { + description: "realistic browser example", + base_http_headers: vec![ + ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "en-US,en;q=0.9"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Referer", "https://www.google.com/"), + ("Upgrade-Insecure-Requests", "1"), + ("x-rama-custom-header-marker", "1"), + ("Cookie", "rama-ua-test=1"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + original_http_header_order: Some(vec![ + "x-show-price", + "x-show-price-currency", + "accept-language", + "cookie", + ]), + original_headers: vec![ + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("cookie", "session=on; foo=bar"), + ("x-requested-with", "XMLHttpRequest"), + ], + preserve_ua_header: false, + expected: vec![ + ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Upgrade-Insecure-Requests", "1"), + ("x-show-price", 
"true"), + ("x-show-price-currency", "USD"), + ("x-requested-with", "XMLHttpRequest"), + ("Cookie", "session=on; foo=bar"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + }, + ]; + + for test_case in test_cases { + let base_http_headers = + Http1HeaderMap::from_iter(test_case.base_http_headers.into_iter().map( + |(name, value)| { + ( + Http1HeaderName::from_str(name).unwrap(), + HeaderValue::from_static(value), + ) + }, + )); + let original_http_header_order = test_case.original_http_header_order.map(|headers| { + OriginalHttp1Headers::from_iter( + headers + .into_iter() + .map(|header| Http1HeaderName::from_str(header).unwrap()), + ) + }); + let original_headers = HeaderMap::from_iter( + test_case.original_headers.into_iter().map(|(name, value)| { + ( + HeaderName::from_static(name), + HeaderValue::from_static(value), + ) + }), + ); + let preserve_ua_header = test_case.preserve_ua_header; + + let output_headers = merge_http_headers( + &base_http_headers, + original_http_header_order, + original_headers, + preserve_ua_header, + ); + + let output_str = output_headers + .into_iter() + .map(|(name, value)| format!("{}: {}\r\n", name, value.to_str().unwrap())) + .join(""); + + let expected_str = test_case + .expected + .iter() + .map(|(name, value)| format!("{}: {}\r\n", name, value)) + .join(""); + + assert_eq!( + output_str, expected_str, + "test case '{}' failed", + test_case.description + ); + } + } + #[test] fn test_get_original_http_header_order() { struct TestCase { From 939e57a5063ed05dc40149b39b9eefee2b20fe78 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sat, 22 Feb 2025 15:04:17 +0100 Subject: [PATCH 27/39] be aware of more opt-in headers + sec-fetch headers only for secure reqs --- rama-ua/Cargo.toml | 4 +- rama-ua/src/emulate/service.rs | 104 ++++++++++++++++++++++++++++++++- 
rama-ua/src/ua/mod.rs | 2 +- rama-ua/src/ua/parse.rs | 19 +++++- 4 files changed, 122 insertions(+), 7 deletions(-) diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index 8dc15686..654cb608 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -15,14 +15,14 @@ workspace = true [features] default = [] -tls = ["dep:rama-net", "rama-net/tls"] +tls = ["rama-net/tls"] [dependencies] bytes = { workspace = true } itertools = { workspace = true } rama-core = { version = "0.2.0-alpha.7", path = "../rama-core" } rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } -rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", optional = true } +rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", features = ["http"] } rama-utils = { version = "0.2.0-alpha.7", path = "../rama-utils" } rand = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index c7c7d160..04d81630 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -7,7 +7,10 @@ use rama_core::{ use rama_http_types::{ HeaderMap, HeaderName, Method, Request, Version, conn::Http1ClientContextParams, - header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_TYPE, COOKIE, REFERER, USER_AGENT}, + header::{ + ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, HOST, ORIGIN, + REFERER, USER_AGENT, + }, proto::{ h1::{ Http1HeaderMap, @@ -16,10 +19,11 @@ use rama_http_types::{ h2::PseudoHeaderOrder, }, }; +use rama_net::{Protocol, http::RequestContext}; use crate::{ CUSTOM_HEADER_MARKER, HttpAgent, RequestInitiator, UserAgent, UserAgentProfile, - contains_ignore_ascii_case, + contains_ignore_ascii_case, starts_with_ignore_ascii_case, }; use super::{UserAgentProvider, UserAgentSelectFallback}; @@ -229,11 +233,21 @@ where .map(|ua| ua.preserve_ua_header()) .unwrap_or_default(); + let is_secure_request = match ctx.get::() { + 
Some(request_ctx) => request_ctx.protocol.is_secure(), + None => req + .uri() + .scheme() + .map(|s| Protocol::from(s.clone()).is_secure()) + .unwrap_or_default(), + }; + let output_headers = merge_http_headers( base_http_headers, original_http_header_order, original_headers, preserve_ua_header, + is_secure_request, ); tracing::trace!( @@ -439,6 +453,7 @@ fn merge_http_headers( original_http_header_order: Option, original_headers: HeaderMap, preserve_ua_header: bool, + is_secure_request: bool, ) -> Http1HeaderMap { let mut original_headers = HeaderMapValueRemover::from(original_headers); @@ -456,7 +471,7 @@ fn merge_http_headers( let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); } - &REFERER | &COOKIE | &AUTHORIZATION => { + &REFERER | &COOKIE | &AUTHORIZATION | &HOST | &ORIGIN | &CONTENT_LENGTH => { if let Some(value) = original_value { output_headers_ref.push((base_name, value)); } @@ -472,6 +487,10 @@ fn merge_http_headers( _ => { if base_header_name == CUSTOM_HEADER_MARKER { output_headers_ref = &mut output_headers_b; + } else if starts_with_ignore_ascii_case(base_header_name.as_str(), "sec-fetch") { + if is_secure_request { + output_headers_ref.push((base_name, base_value)); + } } else { output_headers_ref.push((base_name, base_value)); } @@ -514,6 +533,7 @@ mod tests { original_http_header_order: Option>, original_headers: Vec<(&'static str, &'static str)>, preserve_ua_header: bool, + is_secure_request: bool, expected: Vec<(&'static str, &'static str)>, } @@ -524,6 +544,7 @@ mod tests { original_http_header_order: None, original_headers: vec![], preserve_ua_header: false, + is_secure_request: false, expected: vec![], }, TestCase { @@ -535,6 +556,7 @@ mod tests { original_http_header_order: None, original_headers: vec![], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("Accept", "text/html"), ("Content-Type", "application/json"), @@ -546,6 +568,7 @@ mod tests { original_http_header_order: 
None, original_headers: vec![("accept", "text/html")], preserve_ua_header: false, + is_secure_request: false, expected: vec![("accept", "text/html")], }, TestCase { @@ -554,6 +577,7 @@ mod tests { original_http_header_order: None, original_headers: vec![("content-type", "application/json")], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("user-agent", "python/3.10"), @@ -573,6 +597,7 @@ mod tests { ("user-agent", "php/8.0"), ], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("content-type", "application/json"), @@ -592,6 +617,7 @@ mod tests { ("user-agent", "php/8.0"), ], preserve_ua_header: true, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("content-type", "application/json"), @@ -612,6 +638,7 @@ mod tests { ("user-agent", "php/8.0"), ], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("content-type", "application/json"), @@ -639,6 +666,7 @@ mod tests { ("referer", "https://ramaproxy.org"), ], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("cookie", "foo=bar"), @@ -670,6 +698,7 @@ mod tests { ("authorization", "Bearer 42"), ], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("authorization", "Bearer 42"), @@ -703,6 +732,7 @@ mod tests { ("authorization", "Bearer 42"), ], preserve_ua_header: false, + is_secure_request: false, expected: vec![ ("accept", "text/html"), ("authorization", "Bearer 42"), @@ -751,10 +781,76 @@ mod tests { ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), ("cookie", "session=on; foo=bar"), ("x-requested-with", "XMLHttpRequest"), + ("host", "www.example.com"), ], preserve_ua_header: false, + is_secure_request: false, expected: vec![ + ("Host", "www.example.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like 
Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Upgrade-Insecure-Requests", "1"), + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("x-requested-with", "XMLHttpRequest"), + ("Cookie", "session=on; foo=bar"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + }, + TestCase { + description: "realistic browser example over tls", + base_http_headers: vec![ ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "en-US,en;q=0.9"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Referer", "https://www.google.com/"), + ("Upgrade-Insecure-Requests", "1"), + ("x-rama-custom-header-marker", "1"), + ("Cookie", "rama-ua-test=1"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + original_http_header_order: Some(vec![ + "x-show-price", + "x-show-price-currency", + "accept-language", + "cookie", + ]), + original_headers: vec![ + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("cookie", "session=on; foo=bar"), + ("x-requested-with", "XMLHttpRequest"), + ], + preserve_ua_header: false, + is_secure_request: true, + expected: vec![ ( "User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", @@ -808,12 +904,14 @@ mod 
tests { }), ); let preserve_ua_header = test_case.preserve_ua_header; + let is_secure_request = test_case.is_secure_request; let output_headers = merge_http_headers( &base_http_headers, original_http_header_order, original_headers, preserve_ua_header, + is_secure_request, ); let output_str = output_headers diff --git a/rama-ua/src/ua/mod.rs b/rama-ua/src/ua/mod.rs index 259f5dd9..e869d044 100644 --- a/rama-ua/src/ua/mod.rs +++ b/rama-ua/src/ua/mod.rs @@ -9,8 +9,8 @@ pub use info::{ }; mod parse; -pub(crate) use parse::contains_ignore_ascii_case; use parse::parse_http_user_agent_header; +pub(crate) use parse::{contains_ignore_ascii_case, starts_with_ignore_ascii_case}; /// Information that can be used to overwrite the [`UserAgent`] of an http request. /// diff --git a/rama-ua/src/ua/parse.rs b/rama-ua/src/ua/parse.rs index 13e57e5d..344b649b 100644 --- a/rama-ua/src/ua/parse.rs +++ b/rama-ua/src/ua/parse.rs @@ -199,6 +199,17 @@ pub(crate) fn contains_ignore_ascii_case(s: &str, sub: &str) -> Option { }) } +pub(crate) fn starts_with_ignore_ascii_case(s: &str, sub: &str) -> bool { + let n = sub.len(); + if n > s.len() { + return false; + } + + s.get(..n) + .map(|s| s.eq_ignore_ascii_case(sub)) + .unwrap_or_default() +} + fn contains_any_ignore_ascii_case(s: &str, subs: &[&str]) -> Option { let max = s.len(); let smallest_length = subs.iter().map(|s| s.len()).min().unwrap_or(0); @@ -228,7 +239,13 @@ fn contains_any_ignore_ascii_case(s: &str, subs: &[&str]) -> Option { #[cfg(test)] mod tests { - // test contains_ignore_ascii_case + use super::*; + + #[test] + fn test_starts_with_ignore_ascii_case() { + assert!(starts_with_ignore_ascii_case("user-agent", "user")); + assert!(!starts_with_ignore_ascii_case("user-agent", "agent")); + } #[test] fn test_contains_ignore_ascii_case_empty_sub() { From 96f74f440839fa7929bdd6d1e82fa5343810231a Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sat, 22 Feb 2025 21:23:12 +0100 Subject: [PATCH 28/39] support decompression 
request if no compresion was requested and make sure that we respect accept encoding value, as to make that kind of difficulty easier to handle --- rama-http-types/src/compression.rs | 5 ++ rama-http-types/src/lib.rs | 2 + rama-http-types/src/proto/h1/headers/map.rs | 5 ++ rama-http/src/layer/decompression/layer.rs | 20 +++++++ rama-http/src/layer/decompression/service.rs | 35 ++++++++++++ rama-ua/src/emulate/service.rs | 57 +++++++++++++++----- 6 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 rama-http-types/src/compression.rs diff --git a/rama-http-types/src/compression.rs b/rama-http-types/src/compression.rs new file mode 100644 index 00000000..007869f1 --- /dev/null +++ b/rama-http-types/src/compression.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] +#[non_exhaustive] +/// Marker type that can be used to request an opt-in +/// decompression layer to decompress a body in case it is compressed. +pub struct DecompressIfPossible; diff --git a/rama-http-types/src/lib.rs b/rama-http-types/src/lib.rs index b20e95ba..af66afc5 100644 --- a/rama-http-types/src/lib.rs +++ b/rama-http-types/src/lib.rs @@ -35,6 +35,8 @@ pub use response::{IntoResponse, IntoResponseParts, Response}; pub mod proto; +pub mod compression; + pub mod headers; pub mod conn; diff --git a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs index 05566a7e..d4789804 100644 --- a/rama-http-types/src/proto/h1/headers/map.rs +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -61,6 +61,11 @@ impl Http1HeaderMap { self.headers.get(key) } + #[inline] + pub fn contains_key(&self, key: impl AsHeaderName) -> bool { + self.headers.contains_key(key) + } + pub fn into_headers(self) -> HeaderMap { self.headers } diff --git a/rama-http/src/layer/decompression/layer.rs b/rama-http/src/layer/decompression/layer.rs index 775e3909..37e0d153 100644 --- a/rama-http/src/layer/decompression/layer.rs +++ 
b/rama-http/src/layer/decompression/layer.rs @@ -11,6 +11,7 @@ use rama_core::Layer; #[derive(Debug, Default, Clone)] pub struct DecompressionLayer { accept: AcceptEncoding, + only_if_requested: bool, } impl Layer for DecompressionLayer { @@ -20,6 +21,7 @@ impl Layer for DecompressionLayer { Decompression { inner: service, accept: self.accept, + only_if_requested: self.only_if_requested, } } } @@ -77,4 +79,22 @@ impl DecompressionLayer { self.accept.set_zstd(enable); self } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`rama_http_types::compression::DecompressIfPossible`] marker type. + pub fn only_if_requested(mut self, enable: bool) -> Self { + self.only_if_requested = enable; + self + } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`rama_http_types::compression::DecompressIfPossible`] marker type. + pub fn set_only_if_requested(&mut self, enable: bool) -> &mut Self { + self.only_if_requested = enable; + self + } } diff --git a/rama-http/src/layer/decompression/service.rs b/rama-http/src/layer/decompression/service.rs index cd403159..b132d3cf 100644 --- a/rama-http/src/layer/decompression/service.rs +++ b/rama-http/src/layer/decompression/service.rs @@ -11,6 +11,7 @@ use crate::{ header::{self, ACCEPT_ENCODING}, }; use rama_core::{Context, Service}; +use rama_http_types::compression::DecompressIfPossible; use rama_utils::macros::define_inner_service_accessors; /// Decompresses response bodies of the underlying service. 
@@ -22,6 +23,7 @@ use rama_utils::macros::define_inner_service_accessors; pub struct Decompression { pub(crate) inner: S, pub(crate) accept: AcceptEncoding, + pub(crate) only_if_requested: bool, } impl Decompression { @@ -30,6 +32,7 @@ impl Decompression { Self { inner: service, accept: AcceptEncoding::default(), + only_if_requested: false, } } @@ -82,6 +85,24 @@ impl Decompression { self.accept.set_zstd(enable); self } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`DecompressIfPossible`] marker type. + pub fn only_if_requested(mut self, enable: bool) -> Self { + self.only_if_requested = enable; + self + } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`DecompressIfPossible`] marker type. + pub fn set_only_if_requested(&mut self, enable: bool) -> &mut Self { + self.only_if_requested = enable; + self + } } impl fmt::Debug for Decompression { @@ -89,6 +110,7 @@ impl fmt::Debug for Decompression { f.debug_struct("Decompression") .field("inner", &self.inner) .field("accept", &self.accept) + .field("only_if_requested", &self.only_if_requested) .finish() } } @@ -98,6 +120,7 @@ impl Clone for Decompression { Decompression { inner: self.inner.clone(), accept: self.accept, + only_if_requested: self.only_if_requested, } } } @@ -123,10 +146,22 @@ where } } + let decompression_requested = ctx.contains::(); + let res = self.inner.serve(ctx, req).await?; let (mut parts, body) = res.into_parts(); + if self.only_if_requested + && !(decompression_requested + || parts.extensions.get::().is_some()) + { + return Ok(Response::from_parts( + parts, + DecompressionBody::new(BodyInner::identity(body)), + )); + } + let res = if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) { let body = match entry.get().as_bytes() { diff --git 
a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 04d81630..bc5fb794 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -5,11 +5,12 @@ use rama_core::{ error::{BoxError, ErrorContext, OpaqueError}, }; use rama_http_types::{ - HeaderMap, HeaderName, Method, Request, Version, + HeaderMap, HeaderName, IntoResponse, Method, Request, Response, Version, + compression::DecompressIfPossible, conn::Http1ClientContextParams, header::{ - ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, HOST, ORIGIN, - REFERER, USER_AGENT, + ACCEPT, ACCEPT_ENCODING, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, + COOKIE, HOST, ORIGIN, REFERER, USER_AGENT, }, proto::{ h1::{ @@ -148,10 +149,10 @@ impl Service> for UserAgentEmulateServic where State: Clone + Send + Sync + 'static, Body: Send + Sync + 'static, - S: Service, Error: Into>, + S: Service, Response: IntoResponse, Error: Into>, P: UserAgentProvider, { - type Response = S::Response; + type Response = Response; type Error = BoxError; async fn serve( @@ -190,7 +191,12 @@ where Some(profile) => profile, None => { return if self.optional { - self.inner.serve(ctx, req).await.map_err(Into::into) + Ok(self + .inner + .serve(ctx, req) + .await + .map_err(Into::into)? 
+ .into_response()) } else { Err(OpaqueError::from_display( "requirement not fulfilled: user agent profile could not be selected", @@ -212,6 +218,8 @@ where Some(HttpAgent::Preserve), ); + let mut decompression_marker = None; + if preserve_http { tracing::trace!( ua_kind = %profile.ua_kind, @@ -242,6 +250,8 @@ where .unwrap_or_default(), }; + let requested_compression = original_headers.get(ACCEPT_ENCODING).is_some(); + let output_headers = merge_http_headers( base_http_headers, original_http_header_order, @@ -250,6 +260,10 @@ where is_secure_request, ); + if !requested_compression && output_headers.contains_key(ACCEPT_ENCODING) { + decompression_marker = Some(DecompressIfPossible::default()); + } + tracing::trace!( ua_kind = %profile.ua_kind, ua_version = ?profile.ua_version, @@ -285,7 +299,18 @@ where } // serve emulated http(s) request via inner service - self.inner.serve(ctx, req).await.map_err(Into::into) + let mut res = self + .inner + .serve(ctx, req) + .await + .map_err(Into::into)? 
+ .into_response(); + + if let Some(marker) = decompression_marker { + res.extensions_mut().insert(marker); + } + + Ok(res) } } @@ -467,7 +492,7 @@ fn merge_http_headers( let base_header_name = base_name.header_name(); let original_value = original_headers.remove(base_header_name); match base_header_name { - &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE => { + &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE | &ACCEPT_ENCODING => { let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); } @@ -521,7 +546,9 @@ mod tests { use itertools::Itertools as _; use rama_core::service::service_fn; - use rama_http_types::{Body, HeaderValue, header::ETAG, proto::h1::Http1HeaderName}; + use rama_http_types::{ + Body, BodyExtractExt, HeaderValue, header::ETAG, proto::h1::Http1HeaderName, + }; use crate::{HttpHeadersProfile, HttpProfile}; @@ -1071,7 +1098,8 @@ mod tests { .body(Body::empty()) .unwrap(); let res = ua_service.serve(Context::default(), req).await.unwrap(); - assert_eq!(res, "navigate"); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "navigate"); } #[tokio::test] @@ -1133,7 +1161,8 @@ mod tests { .body(Body::empty()) .unwrap(); let res = ua_service.serve(Context::default(), req).await.unwrap(); - assert_eq!(res, "xhr"); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "xhr"); } #[tokio::test] @@ -1199,7 +1228,8 @@ mod tests { .body(Body::empty()) .unwrap(); let res = ua_service.serve(Context::default(), req).await.unwrap(); - assert_eq!(res, "fetch"); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "fetch"); } #[tokio::test] @@ -1459,7 +1489,8 @@ mod tests { } let ctx = test_case.ctx.unwrap_or_default(); let res = ua_service.serve(ctx, req).await.unwrap(); - assert_eq!(res, test_case.expected, "{}", test_case.description); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, test_case.expected, "{}", 
test_case.description); } } } From 50db61e3c74d2f50e298389d41578c5b6fa0c050 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sat, 22 Feb 2025 21:45:39 +0100 Subject: [PATCH 29/39] differentiate between h1 and h2 headers turns out to be different as well --- rama-ua/src/emulate/service.rs | 552 ++++++++++++++++++++------------- rama-ua/src/profile/db.rs | 53 +++- rama-ua/src/profile/http.rs | 7 +- rama-ua/src/profile/ua.rs | 18 +- 4 files changed, 400 insertions(+), 230 deletions(-) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index bc5fb794..7c71c773 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -23,8 +23,8 @@ use rama_http_types::{ use rama_net::{Protocol, http::RequestContext}; use crate::{ - CUSTOM_HEADER_MARKER, HttpAgent, RequestInitiator, UserAgent, UserAgentProfile, - contains_ignore_ascii_case, starts_with_ignore_ascii_case, + CUSTOM_HEADER_MARKER, HttpAgent, HttpHeadersProfile, RequestInitiator, UserAgent, + UserAgentProfile, contains_ignore_ascii_case, starts_with_ignore_ascii_case, }; use super::{UserAgentProvider, UserAgentSelectFallback}; @@ -229,50 +229,61 @@ where ); } else { emulate_http_settings(&mut ctx, &mut req, profile); - let base_http_headers = get_base_http_headers(&ctx, &req, profile); - let original_http_header_order = - get_original_http_header_order(&ctx, &req, self.input_header_order.as_ref()) + match get_base_http_headers(&ctx, &req, profile) { + Some(base_http_headers) => { + let original_http_header_order = get_original_http_header_order( + &ctx, + &req, + self.input_header_order.as_ref(), + ) .context("collect original http header order")?; - let original_headers = req.headers().clone(); - - let preserve_ua_header = ctx - .get::() - .map(|ua| ua.preserve_ua_header()) - .unwrap_or_default(); - - let is_secure_request = match ctx.get::() { - Some(request_ctx) => request_ctx.protocol.is_secure(), - None => req - .uri() - .scheme() - .map(|s| 
Protocol::from(s.clone()).is_secure()) - .unwrap_or_default(), - }; - - let requested_compression = original_headers.get(ACCEPT_ENCODING).is_some(); + let original_headers = req.headers().clone(); + + let preserve_ua_header = ctx + .get::() + .map(|ua| ua.preserve_ua_header()) + .unwrap_or_default(); + + let is_secure_request = match ctx.get::() { + Some(request_ctx) => request_ctx.protocol.is_secure(), + None => req + .uri() + .scheme() + .map(|s| Protocol::from(s.clone()).is_secure()) + .unwrap_or_default(), + }; + + let requested_compression = original_headers.get(ACCEPT_ENCODING).is_some(); + + let output_headers = merge_http_headers( + base_http_headers, + original_http_header_order, + original_headers, + preserve_ua_header, + is_secure_request, + ); - let output_headers = merge_http_headers( - base_http_headers, - original_http_header_order, - original_headers, - preserve_ua_header, - is_secure_request, - ); + if !requested_compression && output_headers.contains_key(ACCEPT_ENCODING) { + decompression_marker = Some(DecompressIfPossible::default()); + } - if !requested_compression && output_headers.contains_key(ACCEPT_ENCODING) { - decompression_marker = Some(DecompressIfPossible::default()); + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent emulation: http settings and headers emulated" + ); + let (output_headers, original_headers) = output_headers.into_parts(); + *req.headers_mut() = output_headers; + req.extensions_mut().insert(original_headers); + } + None => { + tracing::debug!( + "user agent emulation: no http headers to emulate: no base http headers found" + ); + } } - - tracing::trace!( - ua_kind = %profile.ua_kind, - ua_version = ?profile.ua_version, - platform = ?profile.platform, - "user agent emulation: http settings and headers emulated" - ); - let (output_headers, original_headers) = output_headers.into_parts(); - *req.headers_mut() = output_headers; - 
req.extensions_mut().insert(original_headers); } #[cfg(feature = "tls")] @@ -321,29 +332,26 @@ fn emulate_http_settings( ) { match req.version() { Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => { - if let Some(h1) = &profile.http.h1 { - tracing::trace!( - ua_kind = %profile.ua_kind, - ua_version = ?profile.ua_version, - platform = ?profile.platform, - "UA emulation add http1-specific settings", - ); - ctx.insert(Http1ClientContextParams { - title_header_case: h1.title_case_headers, - }); - } + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add http1-specific settings", + ); + ctx.insert(Http1ClientContextParams { + title_header_case: profile.http.h1.title_case_headers, + }); } Version::HTTP_2 => { - if let Some(h2) = &profile.http.h2 { - tracing::trace!( - ua_kind = %profile.ua_kind, - ua_version = ?profile.ua_version, - platform = ?profile.platform, - "UA emulation add h2-specific settings", - ); - req.extensions_mut() - .insert(PseudoHeaderOrder::from_iter(h2.http_pseudo_headers.iter())); - } + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add h2-specific settings", + ); + req.extensions_mut().insert(PseudoHeaderOrder::from_iter( + profile.http.h2.http_pseudo_headers.iter(), + )); } Version::HTTP_3 => tracing::debug!( "UA emulation not yet supported for h3: not applying anything h3-specific" @@ -359,59 +367,73 @@ fn get_base_http_headers<'a, Body, State>( ctx: &Context, req: &Request, profile: &'a UserAgentProfile, -) -> &'a Http1HeaderMap { - match ctx.get::().and_then(|ua| ua.request_initiator()) { - Some(req_init) => { - tracing::trace!(%req_init, "base http headers defined based on hint from UserAgent (overwrite)"); - get_base_http_headers_from_req_init(req_init, profile) +) -> Option<&'a Http1HeaderMap> { + let headers_profile = match req.version() { + Version::HTTP_09 | 
Version::HTTP_10 | Version::HTTP_11 => &profile.http.h1.headers, + Version::HTTP_2 => &profile.http.h2.headers, + _ => { + tracing::debug!( + version = ?req.version(), + "UA emulation not supported for unknown http version: not applying anything version-specific", + ); + return None; } - // NOTE: the primitive checks below are pretty bad, - // feel free to help improve. Just need to make sure it has good enough fallbacks, - // and that they are cheap enough to check. - None => match *req.method() { - Method::GET => { - let req_init = if headers_contains_partial_value( - req.headers(), - &X_REQUESTED_WITH, - "XmlHttpRequest", - ) { - RequestInitiator::Xhr - } else { - RequestInitiator::Navigate - }; - tracing::trace!(%req_init, "base http headers defined based on Get=NavigateOrXhr assumption"); - get_base_http_headers_from_req_init(req_init, profile) - } - Method::POST => { - let req_init = if headers_contains_partial_value( - req.headers(), - &X_REQUESTED_WITH, - "XmlHttpRequest", - ) { - RequestInitiator::Xhr - } else if headers_contains_partial_value(req.headers(), &CONTENT_TYPE, "form-") { - RequestInitiator::Form - } else { - RequestInitiator::Fetch - }; - tracing::trace!(%req_init, "base http headers defined based on Post=FormOrFetch assumption"); - get_base_http_headers_from_req_init(req_init, profile) - } - _ => { - let req_init = if headers_contains_partial_value( - req.headers(), - &X_REQUESTED_WITH, - "XmlHttpRequest", - ) { - RequestInitiator::Xhr - } else { - RequestInitiator::Fetch - }; - tracing::trace!(%req_init, "base http headers defined based on XhrOrFetch assumption"); - get_base_http_headers_from_req_init(req_init, profile) + }; + Some( + match ctx.get::().and_then(|ua| ua.request_initiator()) { + Some(req_init) => { + tracing::trace!(%req_init, "base http headers defined based on hint from UserAgent (overwrite)"); + get_base_http_headers_from_req_init(req_init, headers_profile) } + // NOTE: the primitive checks below are pretty bad, + // feel 
free to help improve. Just need to make sure it has good enough fallbacks, + // and that they are cheap enough to check. + None => match *req.method() { + Method::GET => { + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else { + RequestInitiator::Navigate + }; + tracing::trace!(%req_init, "base http headers defined based on Get=NavigateOrXhr assumption"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + Method::POST => { + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else if headers_contains_partial_value(req.headers(), &CONTENT_TYPE, "form-") + { + RequestInitiator::Form + } else { + RequestInitiator::Fetch + }; + tracing::trace!(%req_init, "base http headers defined based on Post=FormOrFetch assumption"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + _ => { + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else { + RequestInitiator::Fetch + }; + tracing::trace!(%req_init, "base http headers defined based on XhrOrFetch assumption"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + }, }, - } + ) } static X_REQUESTED_WITH: HeaderName = HeaderName::from_static("x-requested-with"); @@ -426,30 +448,21 @@ fn headers_contains_partial_value(headers: &HeaderMap, name: &HeaderName, value: fn get_base_http_headers_from_req_init( req_init: RequestInitiator, - profile: &UserAgentProfile, + headers: &HttpHeadersProfile, ) -> &Http1HeaderMap { match req_init { - RequestInitiator::Navigate => &profile.http.headers.navigate, - RequestInitiator::Form => profile - .http - .headers - .form - .as_ref() - .unwrap_or(&profile.http.headers.navigate), - RequestInitiator::Xhr => profile - .http - .headers + RequestInitiator::Navigate => 
&headers.navigate, + RequestInitiator::Form => headers.form.as_ref().unwrap_or(&headers.navigate), + RequestInitiator::Xhr => headers .xhr .as_ref() - .or(profile.http.headers.fetch.as_ref()) - .unwrap_or(&profile.http.headers.navigate), - RequestInitiator::Fetch => profile - .http - .headers + .or(headers.fetch.as_ref()) + .unwrap_or(&headers.navigate), + RequestInitiator::Fetch => headers .fetch .as_ref() - .or(profile.http.headers.xhr.as_ref()) - .unwrap_or(&profile.http.headers.navigate), + .or(headers.xhr.as_ref()) + .unwrap_or(&headers.navigate), } } @@ -550,7 +563,7 @@ mod tests { Body, BodyExtractExt, HeaderValue, header::ETAG, proto::h1::Http1HeaderName, }; - use crate::{HttpHeadersProfile, HttpProfile}; + use crate::{Http1Profile, Http2Profile, HttpHeadersProfile, HttpProfile}; #[test] fn test_merge_http_headers() { @@ -1049,6 +1062,77 @@ mod tests { } } + #[tokio::test] + async fn test_get_base_h2_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + title_case_headers: false, + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + http_pseudo_headers: vec![], + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .map(|header| 
header.to_str().unwrap().to_owned()) + .unwrap_or_default(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::GET) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, ""); + + let req = Request::builder() + .method(Method::GET) + .version(Version::HTTP_2) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "navigate"); + } + #[tokio::test] async fn test_get_base_http_headers_profile_with_only_navigate_headers() { let ua = UserAgent::new( @@ -1059,19 +1143,29 @@ mod tests { ua_version: ua.ua_version(), platform: ua.platform(), http: HttpProfile { - headers: HttpHeadersProfile { - navigate: Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("navigate").unwrap())] - .into_iter() - .collect(), - None, - ), - xhr: None, - fetch: None, - form: None, + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + xhr: None, + fetch: None, + form: None, + }, + title_case_headers: false, + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + http_pseudo_headers: vec![], }, - h1: None, - h2: None, }, #[cfg(feature = "tls")] tls: crate::TlsProfile { @@ -1112,29 +1206,39 @@ mod tests { ua_version: ua.ua_version(), platform: ua.platform(), http: HttpProfile { - headers: HttpHeadersProfile { - navigate: Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("navigate").unwrap())] - .into_iter() - .collect(), - None, - ), - xhr: Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("xhr").unwrap())] - .into_iter() - .collect(), - None, - )), - fetch: None, - form: 
Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("form").unwrap())] - .into_iter() - .collect(), - None, - )), + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + xhr: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("xhr").unwrap())] + .into_iter() + .collect(), + None, + )), + fetch: None, + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + title_case_headers: false, + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + http_pseudo_headers: vec![], }, - h1: None, - h2: None, }, #[cfg(feature = "tls")] tls: crate::TlsProfile { @@ -1175,29 +1279,39 @@ mod tests { ua_version: ua.ua_version(), platform: ua.platform(), http: HttpProfile { - headers: HttpHeadersProfile { - navigate: Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("navigate").unwrap())] - .into_iter() - .collect(), - None, - ), - fetch: Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("fetch").unwrap())] - .into_iter() - .collect(), - None, - )), - xhr: None, - form: Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("form").unwrap())] - .into_iter() - .collect(), - None, - )), + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("fetch").unwrap())] + .into_iter() + .collect(), + None, + )), + xhr: None, + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + title_case_headers: false, + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: 
None, + xhr: None, + form: None, + }, + http_pseudo_headers: vec![], }, - h1: None, - h2: None, }, #[cfg(feature = "tls")] tls: crate::TlsProfile { @@ -1242,34 +1356,44 @@ mod tests { ua_version: ua.ua_version(), platform: ua.platform(), http: HttpProfile { - headers: HttpHeadersProfile { - navigate: Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("navigate").unwrap())] - .into_iter() - .collect(), - None, - ), - fetch: Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("fetch").unwrap())] - .into_iter() - .collect(), - None, - )), - xhr: Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("xhr").unwrap())] - .into_iter() - .collect(), - None, - )), - form: Some(Http1HeaderMap::new( - [(ETAG, HeaderValue::from_str("form").unwrap())] - .into_iter() - .collect(), - None, - )), + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("fetch").unwrap())] + .into_iter() + .collect(), + None, + )), + xhr: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("xhr").unwrap())] + .into_iter() + .collect(), + None, + )), + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + title_case_headers: false, + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + http_pseudo_headers: vec![], }, - h1: None, - h2: None, }, #[cfg(feature = "tls")] tls: crate::TlsProfile { diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 1e1e95c1..96aabb5c 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -223,6 +223,19 @@ mod tests { assert_eq!( profile .http + .h1 + .headers + .navigate + .get(USER_AGENT) + .unwrap() + .to_str() + .unwrap(), + "Mozilla/5.0 (Windows NT 10.0; 
Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0" + ); + assert_eq!( + profile + .http + .h2 .headers .navigate .get(USER_AGENT) @@ -360,6 +373,7 @@ mod tests { let rnd = db.rnd().unwrap(); set.insert( rnd.http + .h1 .headers .navigate .get(USER_AGENT) @@ -380,19 +394,34 @@ mod tests { ua_version: ua.ua_version(), platform: ua.platform(), http: crate::HttpProfile { - headers: crate::HttpHeadersProfile { - navigate: Http1HeaderMap::new( - [(USER_AGENT, HeaderValue::from_str(s).unwrap())] - .into_iter() - .collect(), - None, - ), - fetch: None, - xhr: None, - form: None, + h1: crate::Http1Profile { + headers: crate::HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(USER_AGENT, HeaderValue::from_str(s).unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + title_case_headers: false, + }, + h2: crate::Http2Profile { + headers: crate::HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(USER_AGENT, HeaderValue::from_str(s).unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + http_pseudo_headers: vec![], }, - h1: None, - h2: None, }, #[cfg(feature = "tls")] tls: crate::TlsProfile { diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index e2ec24c3..9a231184 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -9,9 +9,8 @@ pub static CUSTOM_HEADER_MARKER: HeaderName = #[derive(Debug, Clone, Deserialize, Serialize)] pub struct HttpProfile { - pub headers: HttpHeadersProfile, - pub h1: Option, - pub h2: Option, + pub h1: Http1Profile, + pub h2: Http2Profile, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -24,10 +23,12 @@ pub struct HttpHeadersProfile { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Http1Profile { + pub headers: HttpHeadersProfile, pub title_case_headers: bool, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Http2Profile { + pub 
headers: HttpHeadersProfile, pub http_pseudo_headers: Vec, } diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs index c53859ac..824217b3 100644 --- a/rama-ua/src/profile/ua.rs +++ b/rama-ua/src/profile/ua.rs @@ -22,10 +22,26 @@ pub struct UserAgentProfile { impl UserAgentProfile { pub fn ua_str(&self) -> Option<&str> { - self.http + if let Some(ua) = self + .http + .h1 .headers .navigate .get(USER_AGENT) .and_then(|v| v.to_str().ok()) + { + Some(ua) + } else if let Some(ua) = self + .http + .h2 + .headers + .navigate + .get(USER_AGENT) + .and_then(|v| v.to_str().ok()) + { + Some(ua) + } else { + None + } } } From 2cc590844658d1503852a599ea3401f4afee58b7 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sun, 23 Feb 2025 22:13:24 +0100 Subject: [PATCH 30/39] refactor q-value and cleanup http deps unrelated to the PR but seems like I kinda got into a refactor drift, whoops --- Cargo.lock | 4 - .../src/headers/common/accept.rs | 23 +- .../src/headers/common/mod.rs | 0 .../src/headers/encoding/accept_encoding.rs | 99 +++++++ .../src/headers/encoding/mod.rs | 263 +++++++++--------- rama-http-types/src/headers/mod.rs | 8 + rama-http-types/src/headers/specifier/mod.rs | 6 + .../src/headers/specifier}/quality_value.rs | 97 +++++-- .../src/headers/util/csv.rs | 4 +- rama-http-types/src/headers/util/mod.rs | 1 + rama-http/Cargo.toml | 4 - .../headers/forwarded/exotic_forward_ip.rs | 2 +- rama-http/src/headers/mod.rs | 18 +- rama-http/src/headers/util/mod.rs | 3 - rama-http/src/io/request.rs | 2 +- rama-http/src/io/response.rs | 2 +- rama-http/src/layer/auth/add_authorization.rs | 13 +- .../layer/auth/async_require_authorization.rs | 2 +- rama-http/src/layer/body_limit.rs | 5 +- rama-http/src/layer/catch_panic.rs | 2 +- .../classify/status_in_range_is_error.rs | 2 +- rama-http/src/layer/compression/body.rs | 4 +- rama-http/src/layer/compression/layer.rs | 3 +- rama-http/src/layer/compression/mod.rs | 8 +- rama-http/src/layer/compression/predicate.rs
| 16 +- rama-http/src/layer/compression/service.rs | 7 +- rama-http/src/layer/cors/tests.rs | 4 +- rama-http/src/layer/decompression/layer.rs | 2 +- .../src/layer/decompression/request/layer.rs | 2 +- .../layer/decompression/request/service.rs | 6 +- rama-http/src/layer/decompression/service.rs | 8 +- rama-http/src/layer/opentelemetry.rs | 10 +- .../src/layer/required_header/request.rs | 8 +- rama-http/src/layer/retry/body.rs | 1 + rama-http/src/layer/retry/policy.rs | 6 +- rama-http/src/layer/trace/body.rs | 2 +- rama-http/src/layer/trace/on_response.rs | 2 +- .../src/layer/traffic_writer/response.rs | 2 +- rama-http/src/layer/ua.rs | 2 +- rama-http/src/layer/util/compression.rs | 86 +----- rama-http/src/layer/util/mod.rs | 2 - rama-http/src/matcher/mod.rs | 35 +-- rama-http/src/service/fs/mod.rs | 2 +- rama-http/src/service/fs/serve_dir/future.rs | 8 +- rama-http/src/service/fs/serve_dir/mod.rs | 10 +- .../src/service/fs/serve_dir/open_file.rs | 21 +- rama-http/src/service/fs/serve_dir/tests.rs | 2 +- rama-http/src/service/fs/serve_file.rs | 2 +- .../service/web/endpoint/extract/authority.rs | 6 +- .../service/web/endpoint/extract/body/csv.rs | 24 +- .../service/web/endpoint/extract/body/json.rs | 18 +- .../src/service/web/endpoint/extract/host.rs | 6 +- .../web/endpoint/extract/typed_header.rs | 2 +- 53 files changed, 481 insertions(+), 396 deletions(-) rename {rama-http => rama-http-types}/src/headers/common/accept.rs (89%) rename {rama-http => rama-http-types}/src/headers/common/mod.rs (100%) create mode 100644 rama-http-types/src/headers/encoding/accept_encoding.rs rename rama-http/src/layer/util/content_encoding.rs => rama-http-types/src/headers/encoding/mod.rs (64%) create mode 100644 rama-http-types/src/headers/specifier/mod.rs rename {rama-http/src/headers/util => rama-http-types/src/headers/specifier}/quality_value.rs (74%) rename {rama-http => rama-http-types}/src/headers/util/csv.rs (89%) create mode 100644 rama-http-types/src/headers/util/mod.rs 
delete mode 100644 rama-http/src/headers/util/mod.rs diff --git a/Cargo.lock b/Cargo.lock index aa08d9f9..25fe2070 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2289,10 +2289,6 @@ dependencies = [ "csv", "flate2", "futures-lite", - "headers", - "http", - "http-body", - "http-body-util", "http-range-header", "httpdate", "iri-string", diff --git a/rama-http/src/headers/common/accept.rs b/rama-http-types/src/headers/common/accept.rs similarity index 89% rename from rama-http/src/headers/common/accept.rs rename to rama-http-types/src/headers/common/accept.rs index caedd494..cdd4300c 100644 --- a/rama-http/src/headers/common/accept.rs +++ b/rama-http-types/src/headers/common/accept.rs @@ -1,5 +1,6 @@ use crate::dep::mime::{self, Mime}; -use crate::headers::{self, Header, QualityValue}; +use crate::headers::specifier::QualityValue; +use crate::headers::{self, Header}; use crate::{HeaderName, HeaderValue}; use std::iter::FromIterator; @@ -35,10 +36,10 @@ fn qitem(mime: Mime) -> QualityValue { /// # Examples /// ``` /// use std::iter::FromIterator; -/// use rama_http::headers::{Accept, QualityValue, HeaderMapExt}; -/// use rama_http::dep::mime; +/// use rama_http_types::headers::{Accept, specifier::QualityValue, HeaderMapExt}; +/// use rama_http_types::dep::mime; /// -/// let mut headers = rama_http::HeaderMap::new(); +/// let mut headers = rama_http_types::HeaderMap::new(); /// /// headers.typed_insert( /// Accept::from_iter(vec![ @@ -49,10 +50,10 @@ fn qitem(mime: Mime) -> QualityValue { /// /// ``` /// use std::iter::FromIterator; -/// use rama_http::headers::{Accept, QualityValue, HeaderMapExt}; -/// use rama_http::dep::mime; +/// use rama_http_types::headers::{Accept, specifier::QualityValue, HeaderMapExt}; +/// use rama_http_types::dep::mime; /// -/// let mut headers = rama_http::HeaderMap::new(); +/// let mut headers = rama_http_types::HeaderMap::new(); /// headers.typed_insert( /// Accept::from_iter(vec![ /// QualityValue::new(mime::APPLICATION_JSON, 
Default::default()), @@ -61,10 +62,10 @@ fn qitem(mime: Mime) -> QualityValue { /// ``` /// ``` /// use std::iter::FromIterator; -/// use rama_http::headers::{Accept, QualityValue, HeaderMapExt}; -/// use rama_http::dep::mime; +/// use rama_http_types::headers::{Accept, specifier::QualityValue, HeaderMapExt}; +/// use rama_http_types::dep::mime; /// -/// let mut headers = rama_http::HeaderMap::new(); +/// let mut headers = rama_http_types::HeaderMap::new(); /// /// headers.typed_insert( /// Accept::from_iter(vec![ @@ -161,7 +162,7 @@ mod tests { mime::{TEXT_HTML, TEXT_PLAIN, TEXT_PLAIN_UTF_8}, }; - use crate::headers::Quality; + use crate::headers::specifier::Quality; macro_rules! test_header { ($name: ident, $input: expr, $expected: expr) => { diff --git a/rama-http/src/headers/common/mod.rs b/rama-http-types/src/headers/common/mod.rs similarity index 100% rename from rama-http/src/headers/common/mod.rs rename to rama-http-types/src/headers/common/mod.rs diff --git a/rama-http-types/src/headers/encoding/accept_encoding.rs b/rama-http-types/src/headers/encoding/accept_encoding.rs new file mode 100644 index 00000000..8d1ac16b --- /dev/null +++ b/rama-http-types/src/headers/encoding/accept_encoding.rs @@ -0,0 +1,99 @@ +use super::SupportedEncodings; +use crate::HeaderValue; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AcceptEncoding { + gzip: bool, + deflate: bool, + br: bool, + zstd: bool, +} + +impl AcceptEncoding { + pub fn maybe_to_header_value(self) -> Option { + let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { + (true, true, true, false) => "gzip,deflate,br", + (true, true, false, false) => "gzip,deflate", + (true, false, true, false) => "gzip,br", + (true, false, false, false) => "gzip", + (false, true, true, false) => "deflate,br", + (false, true, false, false) => "deflate", + (false, false, true, false) => "br", + (true, true, true, true) => "zstd,gzip,deflate,br", + (true, true, false, true) => "zstd,gzip,deflate", + 
(true, false, true, true) => "zstd,gzip,br", + (true, false, false, true) => "zstd,gzip", + (false, true, true, true) => "zstd,deflate,br", + (false, true, false, true) => "zstd,deflate", + (false, false, true, true) => "zstd,br", + (false, false, false, true) => "zstd", + (false, false, false, false) => return None, + }; + Some(HeaderValue::from_static(accept)) + } + + pub fn set_gzip(&mut self, enable: bool) { + self.gzip = enable; + } + + pub fn with_gzip(mut self, enable: bool) -> Self { + self.gzip = enable; + self + } + + pub fn set_deflate(&mut self, enable: bool) { + self.deflate = enable; + } + + pub fn with_deflate(mut self, enable: bool) -> Self { + self.deflate = enable; + self + } + + pub fn set_br(&mut self, enable: bool) { + self.br = enable; + } + + pub fn with_br(mut self, enable: bool) -> Self { + self.br = enable; + self + } + + pub fn set_zstd(&mut self, enable: bool) { + self.zstd = enable; + } + + pub fn with_zstd(mut self, enable: bool) -> Self { + self.zstd = enable; + self + } +} + +impl SupportedEncodings for AcceptEncoding { + fn gzip(&self) -> bool { + self.gzip + } + + fn deflate(&self) -> bool { + self.deflate + } + + fn br(&self) -> bool { + self.br + } + + fn zstd(&self) -> bool { + self.zstd + } +} + +impl Default for AcceptEncoding { + fn default() -> Self { + AcceptEncoding { + gzip: true, + deflate: true, + br: true, + zstd: true, + } + } +} diff --git a/rama-http/src/layer/util/content_encoding.rs b/rama-http-types/src/headers/encoding/mod.rs similarity index 64% rename from rama-http/src/layer/util/content_encoding.rs rename to rama-http-types/src/headers/encoding/mod.rs index 7442191a..57bf9b96 100644 --- a/rama-http/src/layer/util/content_encoding.rs +++ b/rama-http-types/src/headers/encoding/mod.rs @@ -1,18 +1,41 @@ -//! Types and functions for handling content encoding. +//! Utility types and functions that can be used in context of encoding headers. 
use rama_utils::macros::match_ignore_ascii_case_str; +use std::fmt; -pub(crate) trait SupportedEncodings: Copy { +mod accept_encoding; +pub use accept_encoding::AcceptEncoding; + +use super::specifier::{Quality, QualityValue}; + +pub trait SupportedEncodings: Copy { fn gzip(&self) -> bool; fn deflate(&self) -> bool; fn br(&self) -> bool; fn zstd(&self) -> bool; } -// This enum's variants are ordered from least to most preferred. +impl SupportedEncodings for bool { + fn gzip(&self) -> bool { + *self + } + + fn deflate(&self) -> bool { + *self + } + + fn br(&self) -> bool { + *self + } + + fn zstd(&self) -> bool { + *self + } +} + #[derive(Copy, Clone, Debug, Ord, PartialOrd, PartialEq, Eq, Hash)] -pub(crate) enum Encoding { - #[allow(dead_code)] +/// This enum's variants are ordered from least to most preferred. +pub enum Encoding { Identity, Deflate, Gzip, @@ -20,9 +43,21 @@ pub(crate) enum Encoding { Zstd, } +impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl From for http::HeaderValue { + #[inline] + fn from(encoding: Encoding) -> Self { + http::HeaderValue::from_static(encoding.as_str()) + } +} + impl Encoding { - #[allow(dead_code)] - fn to_str(self) -> &'static str { + fn as_str(self) -> &'static str { match self { Encoding::Identity => "identity", Encoding::Gzip => "gzip", @@ -32,8 +67,7 @@ impl Encoding { } } - #[allow(dead_code)] - pub(crate) fn to_file_extension(self) -> Option<&'static std::ffi::OsStr> { + pub fn to_file_extension(self) -> Option<&'static std::ffi::OsStr> { match self { Encoding::Gzip => Some(std::ffi::OsStr::new(".gz")), Encoding::Deflate => Some(std::ffi::OsStr::new(".zz")), @@ -43,122 +77,72 @@ impl Encoding { } } - #[allow(dead_code)] - pub(crate) fn into_header_value(self) -> http::HeaderValue { - http::HeaderValue::from_static(self.to_str()) - } - - fn parse(s: &str, _supported_encoding: impl SupportedEncodings) -> Option { + fn parse(s: &str, 
supported_encoding: impl SupportedEncodings) -> Option { match_ignore_ascii_case_str! { match (s) { - "gzip" | "x-gzip" if _supported_encoding.gzip() => Some(Encoding::Gzip), - "deflate" if _supported_encoding.deflate() => Some(Encoding::Deflate), - "br" if _supported_encoding.br() => Some(Encoding::Brotli), - "zstd" if _supported_encoding.zstd() => Some(Encoding::Zstd), + "gzip" | "x-gzip" if supported_encoding.gzip() => Some(Encoding::Gzip), + "deflate" if supported_encoding.deflate() => Some(Encoding::Deflate), + "br" if supported_encoding.br() => Some(Encoding::Brotli), + "zstd" if supported_encoding.zstd() => Some(Encoding::Zstd), "identity" => Some(Encoding::Identity), _ => None, } } } - #[cfg(feature = "compression")] - // based on https://github.com/http-rs/accept-encoding - #[allow(dead_code)] - pub(crate) fn from_headers( + pub fn maybe_from_content_encoding_header( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, + ) -> Option { + headers + .get(http::header::CONTENT_ENCODING) + .and_then(|hval| hval.to_str().ok()) + .and_then(|s| Encoding::parse(s, supported_encoding)) + } + + #[inline] + pub fn from_content_encoding_header( headers: &http::HeaderMap, supported_encoding: impl SupportedEncodings, ) -> Self { - Encoding::preferred_encoding(encodings(headers, supported_encoding)) + Encoding::maybe_from_content_encoding_header(headers, supported_encoding) .unwrap_or(Encoding::Identity) } - #[allow(dead_code)] - pub(crate) fn preferred_encoding( - accepted_encodings: impl Iterator, + pub fn maybe_from_accept_encoding_headers( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, ) -> Option { - accepted_encodings - .filter(|(_, qvalue)| qvalue.0 > 0) - .max_by_key(|&(encoding, qvalue)| (qvalue, encoding)) - .map(|(encoding, _)| encoding) + Encoding::maybe_preferred_encoding(parse_accept_encoding_headers( + headers, + supported_encoding, + )) } -} - -// Allowed q-values are numbers between 0 and 1 with at most 
3 digits in the fractional part. They -// are presented here as an unsigned integer between 0 and 1000. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) struct QValue(u16); -impl QValue { #[inline] - fn one() -> Self { - Self(1000) - } - - // Parse a q-value as specified in RFC 7231 section 5.3.1. - fn parse(s: &str) -> Option { - let mut c = s.chars(); - // Parse "q=" (case-insensitively). - match c.next() { - Some('q' | 'Q') => (), - _ => return None, - }; - match c.next() { - Some('=') => (), - _ => return None, - }; - - // Parse leading digit. Since valid q-values are between 0.000 and 1.000, only "0" and "1" - // are allowed. - let mut value = match c.next() { - Some('0') => 0, - Some('1') => 1000, - _ => return None, - }; - - // Parse optional decimal point. - match c.next() { - Some('.') => (), - None => return Some(Self(value)), - _ => return None, - }; - - // Parse optional fractional digits. The value of each digit is multiplied by `factor`. - // Since the q-value is represented as an integer between 0 and 1000, `factor` is `100` for - // the first digit, `10` for the next, and `1` for the digit after that. - let mut factor = 100; - loop { - match c.next() { - Some(n @ '0'..='9') => { - // If `factor` is less than `1`, three digits have already been parsed. A - // q-value having more than 3 fractional digits is invalid. - if factor < 1 { - return None; - } - // Add the digit's value multiplied by `factor` to `value`. - value += factor * (n as u16 - '0' as u16); - } - None => { - // No more characters to parse. Check that the value representing the q-value is - // in the valid range. 
- return if value <= 1000 { - Some(Self(value)) - } else { - None - }; - } - _ => return None, - }; - factor /= 10; - } + pub fn from_accept_encoding_headers( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, + ) -> Self { + Encoding::maybe_from_accept_encoding_headers(headers, supported_encoding) + .unwrap_or(Encoding::Identity) + } + + pub fn maybe_preferred_encoding( + accepted_encodings: impl Iterator>, + ) -> Option { + accepted_encodings + .filter(|qval| qval.quality.as_u16() > 0) + .max_by_key(|qval| (qval.quality, qval.value)) + .map(|qval| qval.value) } } -// based on https://github.com/http-rs/accept-encoding -#[allow(dead_code)] -pub(crate) fn encodings<'a>( +/// based on https://github.com/http-rs/accept-encoding +pub fn parse_accept_encoding_headers<'a>( headers: &'a http::HeaderMap, supported_encoding: impl SupportedEncodings + 'a, -) -> impl Iterator + 'a { +) -> impl Iterator> + 'a { headers .get_all(http::header::ACCEPT_ENCODING) .iter() @@ -173,16 +157,16 @@ pub(crate) fn encodings<'a>( }; let qval = if let Some(qval) = v.next() { - QValue::parse(qval.trim())? + qval.trim().parse::().ok()? 
} else { - QValue::one() + Quality::one() }; - Some((encoding, qval)) + Some(QualityValue::new(encoding, qval)) }) } -#[cfg(all(test, feature = "compression"))] +#[cfg(test)] mod tests { use super::*; @@ -209,7 +193,8 @@ mod tests { #[test] fn no_accept_encoding_header() { - let encoding = Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll); + let encoding = + Encoding::from_accept_encoding_headers(&http::HeaderMap::new(), SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } @@ -220,7 +205,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -231,7 +216,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -242,7 +227,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,x-gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -253,7 +238,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("deflate,x-gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -264,7 +249,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,deflate,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = 
Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -275,7 +260,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -286,7 +271,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -301,7 +286,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -316,7 +301,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -335,7 +320,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -346,7 +331,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); 
assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); @@ -354,7 +339,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -362,7 +347,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.995,br;q=0.999"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -373,7 +358,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate;q=0.6,br;q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); @@ -381,7 +366,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,deflate;q=0.6,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -389,7 +374,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.6,deflate;q=0.8,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Deflate, encoding); let mut headers = http::HeaderMap::new(); @@ -397,7 +382,7 @@ mod tests { http::header::ACCEPT_ENCODING, 
http::HeaderValue::from_static("gzip;q=0.995,deflate;q=0.997,br;q=0.999"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -408,7 +393,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("invalid,gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -419,7 +404,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -427,7 +412,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0."), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -435,7 +420,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -446,7 +431,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gZiP"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ 
-454,7 +439,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;Q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -465,7 +450,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static(" gzip\t; q=0.5 ,\tbr ;\tq=0.8\t"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -476,7 +461,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q =0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -484,7 +469,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q= 0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } @@ -495,7 +480,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=-0.1"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -503,7 +488,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=00.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); 
assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -511,7 +496,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5000"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -519,7 +504,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -527,7 +512,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.01"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -535,7 +520,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.001"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } } diff --git a/rama-http-types/src/headers/mod.rs b/rama-http-types/src/headers/mod.rs index 13ea70b8..5b78013d 100644 --- a/rama-http-types/src/headers/mod.rs +++ b/rama-http-types/src/headers/mod.rs @@ -90,3 +90,11 @@ pub mod authorization { mod ext; #[doc(inline)] pub use ext::HeaderExt; + +pub mod encoding; +pub mod specifier; + +pub mod util; + +mod common; +pub use common::Accept; diff --git a/rama-http-types/src/headers/specifier/mod.rs b/rama-http-types/src/headers/specifier/mod.rs new file mode 
100644 index 00000000..aca3af97 --- /dev/null +++ b/rama-http-types/src/headers/specifier/mod.rs @@ -0,0 +1,6 @@ +//! Specifiers that can be used as part of header values. +//! +//! An example is the [`QValue`] used in function of several headers such as 'accept-encoding'. + +mod quality_value; +pub use quality_value::{Quality, QualityValue}; diff --git a/rama-http/src/headers/util/quality_value.rs b/rama-http-types/src/headers/specifier/quality_value.rs similarity index 74% rename from rama-http/src/headers/util/quality_value.rs rename to rama-http-types/src/headers/specifier/quality_value.rs index c0328b13..bbb22ebb 100644 --- a/rama-http/src/headers/util/quality_value.rs +++ b/rama-http-types/src/headers/specifier/quality_value.rs @@ -24,6 +24,80 @@ use self::internal::IntoQuality; #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)] pub struct Quality(u16); +impl Quality { + #[inline] + pub fn one() -> Self { + Self(1000) + } + + #[inline] + pub fn as_u16(&self) -> u16 { + self.0 + } +} + +impl str::FromStr for Quality { + type Err = crate::headers::Error; + + // Parse a q-value as specified in RFC 7231 section 5.3.1. + fn from_str(s: &str) -> Result { + let mut c = s.chars(); + // Parse "q=" (case-insensitively). + match c.next() { + Some('q' | 'Q') => (), + _ => return Err(crate::headers::Error::invalid()), + }; + match c.next() { + Some('=') => (), + _ => return Err(crate::headers::Error::invalid()), + }; + + // Parse leading digit. Since valid q-values are between 0.000 and 1.000, only "0" and "1" + // are allowed. + let mut value = match c.next() { + Some('0') => 0, + Some('1') => 1000, + _ => return Err(crate::headers::Error::invalid()), + }; + + // Parse optional decimal point. + match c.next() { + Some('.') => (), + None => return Ok(Self(value)), + _ => return Err(crate::headers::Error::invalid()), + }; + + // Parse optional fractional digits. The value of each digit is multiplied by `factor`. 
+ // Since the q-value is represented as an integer between 0 and 1000, `factor` is `100` for + // the first digit, `10` for the next, and `1` for the digit after that. + let mut factor = 100; + loop { + match c.next() { + Some(n @ '0'..='9') => { + // If `factor` is less than `1`, three digits have already been parsed. A + // q-value having more than 3 fractional digits is invalid. + if factor < 1 { + return Err(crate::headers::Error::invalid()); + } + // Add the digit's value multiplied by `factor` to `value`. + value += factor * (n as u16 - '0' as u16); + } + None => { + // No more characters to parse. Check that the value representing the q-value is + // in the valid range. + return if value <= 1000 { + Ok(Self(value)) + } else { + Err(crate::headers::Error::invalid()) + }; + } + _ => return Err(crate::headers::Error::invalid()), + }; + factor /= 10; + } + } +} + impl Default for Quality { fn default() -> Quality { Quality(1000) @@ -40,6 +114,8 @@ pub struct QualityValue { pub quality: Quality, } +impl Copy for QualityValue {} + impl QualityValue { /// Creates a new `QualityValue` from an item and a quality. pub const fn new(value: T, quality: Quality) -> QualityValue { @@ -92,7 +168,7 @@ impl str::FromStr for QualityValue { fn from_str(s: &str) -> Result, crate::headers::Error> { // Set defaults used if parsing fails. 
let mut raw_item = s; - let mut quality = 1f32; + let mut quality = Quality::one(); let mut parts = s.rsplitn(2, ';').map(|x| x.trim()); if let (Some(first), Some(second), None) = (parts.next(), parts.next(), parts.next()) { @@ -100,26 +176,13 @@ impl str::FromStr for QualityValue { return Err(crate::headers::Error::invalid()); } if first.starts_with("q=") || first.starts_with("Q=") { - let q_part = &first[2..]; - if q_part.len() > 5 { - return Err(crate::headers::Error::invalid()); - } - match q_part.parse::() { - Ok(q_value) => { - if (0f32..=1f32).contains(&q_value) { - quality = q_value; - raw_item = second; - } else { - return Err(crate::headers::Error::invalid()); - } - } - Err(_) => return Err(crate::headers::Error::invalid()), - } + quality = Quality::from_str(first)?; + raw_item = second; } } match raw_item.parse::() { // we already checked above that the quality is within range - Ok(item) => Ok(QualityValue::new(item, from_f32(quality))), + Ok(item) => Ok(QualityValue::new(item, quality)), Err(_) => Err(crate::headers::Error::invalid()), } } diff --git a/rama-http/src/headers/util/csv.rs b/rama-http-types/src/headers/util/csv.rs similarity index 89% rename from rama-http/src/headers/util/csv.rs rename to rama-http-types/src/headers/util/csv.rs index c9f65c59..350a072d 100644 --- a/rama-http/src/headers/util/csv.rs +++ b/rama-http-types/src/headers/util/csv.rs @@ -8,7 +8,7 @@ use crate::HeaderValue; use crate::headers::Error; /// Reads a comma-delimited raw header into a Vec. -pub(crate) fn from_comma_delimited<'i, I, T, E>(values: &mut I) -> Result +pub fn from_comma_delimited<'i, I, T, E>(values: &mut I) -> Result where I: Iterator, T: ::std::str::FromStr, @@ -30,7 +30,7 @@ where } /// Format an array into a comma-delimited string. 
-pub(crate) fn fmt_comma_delimited( +pub fn fmt_comma_delimited( f: &mut fmt::Formatter, mut iter: impl Iterator, ) -> fmt::Result { diff --git a/rama-http-types/src/headers/util/mod.rs b/rama-http-types/src/headers/util/mod.rs new file mode 100644 index 00000000..d4996f0b --- /dev/null +++ b/rama-http-types/src/headers/util/mod.rs @@ -0,0 +1 @@ +pub mod csv; diff --git a/rama-http/Cargo.toml b/rama-http/Cargo.toml index ac596aa2..18f7ebcc 100644 --- a/rama-http/Cargo.toml +++ b/rama-http/Cargo.toml @@ -33,10 +33,6 @@ bytes = { workspace = true } const_format = { workspace = true } csv = { workspace = true } futures-lite = { workspace = true } -headers = { workspace = true } -http = { workspace = true } -http-body = { workspace = true } -http-body-util = { workspace = true } http-range-header = { workspace = true } httpdate = { workspace = true } iri-string = { workspace = true } diff --git a/rama-http/src/headers/forwarded/exotic_forward_ip.rs b/rama-http/src/headers/forwarded/exotic_forward_ip.rs index e715801f..f4814846 100644 --- a/rama-http/src/headers/forwarded/exotic_forward_ip.rs +++ b/rama-http/src/headers/forwarded/exotic_forward_ip.rs @@ -91,7 +91,7 @@ macro_rules! exotic_forward_ip_headers { fn decode<'i, I: Iterator>( values: &mut I, - ) -> Result { + ) -> Result { Ok($name( values .next() diff --git a/rama-http/src/headers/mod.rs b/rama-http/src/headers/mod.rs index 06dc4a44..ef8e3e0b 100644 --- a/rama-http/src/headers/mod.rs +++ b/rama-http/src/headers/mod.rs @@ -20,12 +20,13 @@ //! //! ```rust //! use rama_http::{headers::Header, HeaderName, HeaderValue}; +//! use rama_http_types::{header, headers}; //! //! struct Dnt(bool); //! //! impl Header for Dnt { //! fn name() -> &'static HeaderName { -//! &http::header::DNT +//! &header::DNT //! } //! //! 
fn decode<'i, I>(values: &mut I) -> Result @@ -67,7 +68,7 @@ pub use rama_http_types::headers::{Header, HeaderMapExt}; #[doc(inline)] pub use rama_http_types::headers::{ - AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders, + Accept, AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin, AccessControlExposeHeaders, AccessControlMaxAge, AccessControlRequestHeaders, AccessControlRequestMethod, Age, Allow, Authorization, CacheControl, Connection, ContentDisposition, ContentEncoding, ContentLength, @@ -78,9 +79,11 @@ pub use rama_http_types::headers::{ StrictTransportSecurity, Te, TransferEncoding, Upgrade, UserAgent, Vary, }; -mod common; #[doc(inline)] -pub use common::Accept; +pub use rama_http_types::headers::specifier::{Quality, QualityValue}; + +#[doc(inline)] +pub use rama_http_types::headers::util; mod forwarded; #[doc(inline)] @@ -93,13 +96,10 @@ pub mod authorization { //! Authorization header and types. #[doc(inline)] - pub use ::headers::authorization::Credentials; + pub use rama_http_types::headers::authorization::Credentials; #[doc(inline)] - pub use ::headers::authorization::{Authorization, Basic, Bearer}; + pub use rama_http_types::headers::authorization::{Authorization, Basic, Bearer}; } #[doc(inline)] pub use ::rama_http_types::headers::HeaderExt; - -pub(crate) mod util; -pub use util::quality_value::{Quality, QualityValue}; diff --git a/rama-http/src/headers/util/mod.rs b/rama-http/src/headers/util/mod.rs deleted file mode 100644 index af107ff1..00000000 --- a/rama-http/src/headers/util/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) mod csv; -/// Internal utility functions for headers. 
-pub(crate) mod quality_value; diff --git a/rama-http/src/io/request.rs b/rama-http/src/io/request.rs index d890f0e5..81a4342a 100644 --- a/rama-http/src/io/request.rs +++ b/rama-http/src/io/request.rs @@ -85,7 +85,7 @@ where for (name, value) in header_map { match parts.version { - http::Version::HTTP_2 | http::Version::HTTP_3 => { + rama_http_types::Version::HTTP_2 | rama_http_types::Version::HTTP_3 => { // write lower-case for H2/H3 w.write_all( format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?) diff --git a/rama-http/src/io/response.rs b/rama-http/src/io/response.rs index 260924fd..2061befb 100644 --- a/rama-http/src/io/response.rs +++ b/rama-http/src/io/response.rs @@ -73,7 +73,7 @@ where for (name, value) in header_map { match parts.version { - http::Version::HTTP_2 | http::Version::HTTP_3 => { + rama_http_types::Version::HTTP_2 | rama_http_types::Version::HTTP_3 => { // write lower-case for H2/H3 w.write_all( format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?) 
diff --git a/rama-http/src/layer/auth/add_authorization.rs b/rama-http/src/layer/auth/add_authorization.rs index d65e8066..3f549c3a 100644 --- a/rama-http/src/layer/auth/add_authorization.rs +++ b/rama-http/src/layer/auth/add_authorization.rs @@ -287,9 +287,13 @@ where mut req: Request, ) -> Result { if let Some(value) = &self.value { - if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) { + if !self.if_not_present + || !req + .headers() + .contains_key(rama_http_types::header::AUTHORIZATION) + { req.headers_mut() - .insert(http::header::AUTHORIZATION, value.clone()); + .insert(rama_http_types::header::AUTHORIZATION, value.clone()); } } self.inner.serve(ctx, req).await @@ -344,7 +348,10 @@ mod tests { async fn making_header_sensitive() { let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn( |request: Request| async move { - let auth = request.headers().get(http::header::AUTHORIZATION).unwrap(); + let auth = request + .headers() + .get(rama_http_types::header::AUTHORIZATION) + .unwrap(); assert!(auth.is_sensitive()); Ok::<_, Infallible>(Response::new(Body::empty())) diff --git a/rama-http/src/layer/auth/async_require_authorization.rs b/rama-http/src/layer/auth/async_require_authorization.rs index 1b72abba..52c6a8c3 100644 --- a/rama-http/src/layer/auth/async_require_authorization.rs +++ b/rama-http/src/layer/auth/async_require_authorization.rs @@ -271,7 +271,7 @@ mod tests { let authorized = request .headers() .get(header::AUTHORIZATION) - .and_then(|it: &http::HeaderValue| it.to_str().ok()) + .and_then(|it: &rama_http_types::HeaderValue| it.to_str().ok()) .and_then(|it| it.strip_prefix("Bearer ")) .map(|it| it == "69420") .unwrap_or(false); diff --git a/rama-http/src/layer/body_limit.rs b/rama-http/src/layer/body_limit.rs index 574805c5..1a7af821 100644 --- a/rama-http/src/layer/body_limit.rs +++ b/rama-http/src/layer/body_limit.rs @@ -85,7 +85,10 @@ impl Service> for BodyLimitService where S: Service>, State: Clone + 
Send + Sync + 'static, - ReqBody: http_body::Body> + Send + Sync + 'static, + ReqBody: rama_http_types::dep::http_body::Body> + + Send + + Sync + + 'static, { type Response = S::Response; type Error = S::Error; diff --git a/rama-http/src/layer/catch_panic.rs b/rama-http/src/layer/catch_panic.rs index 7cc97b17..8caa8d72 100644 --- a/rama-http/src/layer/catch_panic.rs +++ b/rama-http/src/layer/catch_panic.rs @@ -280,7 +280,7 @@ impl ResponseForPanic for DefaultResponseForPanic { #[allow(clippy::declare_interior_mutable_const)] const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); res.headers_mut() - .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); + .insert(rama_http_types::header::CONTENT_TYPE, TEXT_PLAIN); res } diff --git a/rama-http/src/layer/classify/status_in_range_is_error.rs b/rama-http/src/layer/classify/status_in_range_is_error.rs index b23569da..db3e4db2 100644 --- a/rama-http/src/layer/classify/status_in_range_is_error.rs +++ b/rama-http/src/layer/classify/status_in_range_is_error.rs @@ -53,7 +53,7 @@ impl ClassifyResponse for StatusInRangeAsFailures { fn classify_response( self, - res: &http::Response, + res: &rama_http_types::Response, ) -> ClassifiedResponse { if self.range.contains(&res.status().as_u16()) { let class = StatusInRangeFailureClass::StatusCode(res.status()); diff --git a/rama-http/src/layer/compression/body.rs b/rama-http/src/layer/compression/body.rs index 3e616184..b90d3f1f 100644 --- a/rama-http/src/layer/compression/body.rs +++ b/rama-http/src/layer/compression/body.rs @@ -143,11 +143,11 @@ where } } - fn size_hint(&self) -> http_body::SizeHint { + fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint { if let BodyInner::Identity { inner } = &self.inner { inner.size_hint() } else { - http_body::SizeHint::new() + rama_http_types::dep::http_body::SizeHint::new() } } diff --git a/rama-http/src/layer/compression/layer.rs b/rama-http/src/layer/compression/layer.rs index 59ec6fee..2e98327c 100644 
--- a/rama-http/src/layer/compression/layer.rs +++ b/rama-http/src/layer/compression/layer.rs @@ -1,7 +1,8 @@ use super::predicate::DefaultPredicate; use super::{Compression, Predicate}; -use crate::layer::util::compression::{AcceptEncoding, CompressionLevel}; +use crate::layer::util::compression::CompressionLevel; use rama_core::Layer; +use rama_http_types::headers::encoding::AcceptEncoding; /// Compress response bodies of the underlying service. /// diff --git a/rama-http/src/layer/compression/mod.rs b/rama-http/src/layer/compression/mod.rs index 2d8c66f9..c44efb2d 100644 --- a/rama-http/src/layer/compression/mod.rs +++ b/rama-http/src/layer/compression/mod.rs @@ -117,9 +117,9 @@ mod tests { struct Always; impl Predicate for Always { - fn should_compress(&self, _: &http::Response) -> bool + fn should_compress(&self, _: &rama_http_types::Response) -> bool where - B: http_body::Body, + B: rama_http_types::dep::http_body::Body, { true } @@ -295,9 +295,9 @@ mod tests { #[allow(clippy::dbg_macro)] impl Predicate for EveryOtherResponse { - fn should_compress(&self, _: &http::Response) -> bool + fn should_compress(&self, _: &rama_http_types::Response) -> bool where - B: http_body::Body, + B: rama_http_types::dep::http_body::Body, { let mut guard = self.0.write().unwrap(); let should_compress = *guard % 2 != 0; diff --git a/rama-http/src/layer/compression/predicate.rs b/rama-http/src/layer/compression/predicate.rs index 967fc508..505c225c 100644 --- a/rama-http/src/layer/compression/predicate.rs +++ b/rama-http/src/layer/compression/predicate.rs @@ -13,7 +13,7 @@ use std::{fmt, sync::Arc}; /// Predicate used to determine if a response should be compressed or not. pub trait Predicate: Clone { /// Should this response be compressed or not? 
- fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body; @@ -36,7 +36,7 @@ impl Predicate for F where F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone, { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -52,7 +52,7 @@ impl Predicate for Option where T: Predicate, { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -76,7 +76,7 @@ where Lhs: Predicate, Rhs: Predicate, { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -136,7 +136,7 @@ impl Default for DefaultPredicate { } impl Predicate for DefaultPredicate { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -168,7 +168,7 @@ impl Default for SizeAbove { } impl Predicate for SizeAbove { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -225,7 +225,7 @@ impl NotForContentType { } impl Predicate for NotForContentType { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -263,7 +263,7 @@ impl fmt::Debug for Str { } } -fn content_type(response: &http::Response) -> &str { +fn content_type(response: &rama_http_types::Response) -> &str { response .headers() .get(header::CONTENT_TYPE) diff --git a/rama-http/src/layer/compression/service.rs b/rama-http/src/layer/compression/service.rs index 8d48fa36..c3bfe016 100644 --- a/rama-http/src/layer/compression/service.rs +++ 
b/rama-http/src/layer/compression/service.rs @@ -4,9 +4,10 @@ use super::body::BodyInner; use super::predicate::{DefaultPredicate, Predicate}; use crate::dep::http_body::Body; use crate::layer::util::compression::WrapBody; -use crate::layer::util::{compression::AcceptEncoding, content_encoding::Encoding}; use crate::{Request, Response, header}; use rama_core::{Context, Service}; +use rama_http_types::HeaderValue; +use rama_http_types::headers::encoding::{AcceptEncoding, Encoding}; use rama_utils::macros::define_inner_service_accessors; /// Compress response bodies of the underlying service. @@ -192,7 +193,7 @@ where ctx: Context, req: Request, ) -> Result { - let encoding = Encoding::from_headers(req.headers(), self.accept); + let encoding = Encoding::from_accept_encoding_headers(req.headers(), self.accept); let res = self.inner.serve(ctx, req).await?; @@ -259,7 +260,7 @@ where parts .headers - .insert(header::CONTENT_ENCODING, encoding.into_header_value()); + .insert(header::CONTENT_ENCODING, HeaderValue::from(encoding)); let res = Response::from_parts(parts, body); Ok(res) diff --git a/rama-http/src/layer/cors/tests.rs b/rama-http/src/layer/cors/tests.rs index 312fd256..2642f34a 100644 --- a/rama-http/src/layer/cors/tests.rs +++ b/rama-http/src/layer/cors/tests.rs @@ -57,7 +57,7 @@ async fn test_allow_origin_async_predicate() { }); let valid_origin = HeaderValue::from_static("http://example.com"); - let parts = http::Request::new("hello world").into_parts().0; + let parts = rama_http_types::Request::new("hello world").into_parts().0; let header = allow_origin .to_future(Some(&valid_origin), &parts) @@ -67,7 +67,7 @@ async fn test_allow_origin_async_predicate() { assert_eq!(header.1, valid_origin); let invalid_origin = HeaderValue::from_static("http://example.org"); - let parts = http::Request::new("hello world").into_parts().0; + let parts = rama_http_types::Request::new("hello world").into_parts().0; let res = allow_origin.to_future(Some(&invalid_origin), 
&parts).await; assert!(res.is_none()); diff --git a/rama-http/src/layer/decompression/layer.rs b/rama-http/src/layer/decompression/layer.rs index 37e0d153..df90d234 100644 --- a/rama-http/src/layer/decompression/layer.rs +++ b/rama-http/src/layer/decompression/layer.rs @@ -1,6 +1,6 @@ use super::Decompression; -use crate::layer::util::compression::AcceptEncoding; use rama_core::Layer; +use rama_http_types::headers::encoding::AcceptEncoding; /// Decompresses response bodies of the underlying service. /// diff --git a/rama-http/src/layer/decompression/request/layer.rs b/rama-http/src/layer/decompression/request/layer.rs index 985e8e62..6e1b97e5 100644 --- a/rama-http/src/layer/decompression/request/layer.rs +++ b/rama-http/src/layer/decompression/request/layer.rs @@ -1,6 +1,6 @@ use super::service::RequestDecompression; -use crate::layer::util::compression::AcceptEncoding; use rama_core::Layer; +use rama_http_types::headers::encoding::AcceptEncoding; /// Decompresses request bodies and calls its underlying service. 
/// diff --git a/rama-http/src/layer/decompression/request/service.rs b/rama-http/src/layer/decompression/request/service.rs index be2aa2fe..2e35d54c 100644 --- a/rama-http/src/layer/decompression/request/service.rs +++ b/rama-http/src/layer/decompression/request/service.rs @@ -5,13 +5,13 @@ use crate::dep::http_body_util::{BodyExt, Empty, combinators::UnsyncBoxBody}; use crate::layer::{ decompression::DecompressionBody, decompression::body::BodyInner, - util::compression::{AcceptEncoding, CompressionLevel, WrapBody}, - util::content_encoding::SupportedEncodings, + util::compression::{CompressionLevel, WrapBody}, }; use crate::{HeaderValue, Request, Response, StatusCode, header}; use bytes::Buf; use rama_core::error::BoxError; use rama_core::{Context, Service}; +use rama_http_types::headers::encoding::{AcceptEncoding, SupportedEncodings}; use rama_utils::macros::define_inner_service_accessors; /// Decompresses request bodies and calls its underlying service. @@ -124,7 +124,7 @@ where .header( header::ACCEPT_ENCODING, accept - .to_header_value() + .maybe_to_header_value() .unwrap_or(HeaderValue::from_static("identity")), ) .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) diff --git a/rama-http/src/layer/decompression/service.rs b/rama-http/src/layer/decompression/service.rs index b132d3cf..d3c23c02 100644 --- a/rama-http/src/layer/decompression/service.rs +++ b/rama-http/src/layer/decompression/service.rs @@ -2,16 +2,14 @@ use std::fmt; use super::{DecompressionBody, body::BodyInner}; use crate::dep::http_body::Body; -use crate::layer::util::{ - compression::{AcceptEncoding, CompressionLevel, WrapBody}, - content_encoding::SupportedEncodings, -}; +use crate::layer::util::compression::{CompressionLevel, WrapBody}; use crate::{ Request, Response, header::{self, ACCEPT_ENCODING}, }; use rama_core::{Context, Service}; use rama_http_types::compression::DecompressIfPossible; +use rama_http_types::headers::encoding::{AcceptEncoding, SupportedEncodings}; use 
rama_utils::macros::define_inner_service_accessors; /// Decompresses response bodies of the underlying service. @@ -141,7 +139,7 @@ where mut req: Request, ) -> Result { if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) { - if let Some(accept) = self.accept.to_header_value() { + if let Some(accept) = self.accept.maybe_to_header_value() { entry.insert(accept); } } diff --git a/rama-http/src/layer/opentelemetry.rs b/rama-http/src/layer/opentelemetry.rs index 289ff56f..24ce3fcf 100644 --- a/rama-http/src/layer/opentelemetry.rs +++ b/rama-http/src/layer/opentelemetry.rs @@ -267,11 +267,11 @@ impl RequestMetricsService { attributes.push(KeyValue::new(HTTP_REQUEST_METHOD, req.method().to_string())); if let Some(http_version) = request_ctx.as_ref().and_then(|rc| match rc.http_version { - http::Version::HTTP_09 => Some("0.9"), - http::Version::HTTP_10 => Some("1.0"), - http::Version::HTTP_11 => Some("1.1"), - http::Version::HTTP_2 => Some("2"), - http::Version::HTTP_3 => Some("3"), + rama_http_types::Version::HTTP_09 => Some("0.9"), + rama_http_types::Version::HTTP_10 => Some("1.0"), + rama_http_types::Version::HTTP_11 => Some("1.1"), + rama_http_types::Version::HTTP_2 => Some("2"), + rama_http_types::Version::HTTP_3 => Some("3"), _ => None, }) { attributes.push(KeyValue::new(NETWORK_PROTOCOL_VERSION, http_version)); diff --git a/rama-http/src/layer/required_header/request.rs b/rama-http/src/layer/required_header/request.rs index 748b0de3..4ce8db2f 100644 --- a/rama-http/src/layer/required_header/request.rs +++ b/rama-http/src/layer/required_header/request.rs @@ -228,7 +228,7 @@ mod test { |_ctx: Context<()>, req: Request| async move { assert!(req.headers().contains_key(HOST)); assert!(req.headers().contains_key(USER_AGENT)); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) }, )); @@ -252,7 +252,7 @@ mod test { req.headers().get(USER_AGENT).and_then(|v| 
v.to_str().ok()), Some("foo") ); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) })); let req = Request::builder() @@ -275,7 +275,7 @@ mod test { req.headers().get(USER_AGENT).unwrap(), RAMA_ID_HEADER_VALUE.to_str().unwrap() ); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) })); let req = Request::builder() @@ -302,7 +302,7 @@ mod test { req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()), Some("foo") ); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) })); let req = Request::builder() diff --git a/rama-http/src/layer/retry/body.rs b/rama-http/src/layer/retry/body.rs index 5eefa01e..47d0ff15 100644 --- a/rama-http/src/layer/retry/body.rs +++ b/rama-http/src/layer/retry/body.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use rama_http_types::dep::http_body; #[derive(Debug, Clone)] /// A body that can be clone and used for requests that have to be rertried. diff --git a/rama-http/src/layer/retry/policy.rs b/rama-http/src/layer/retry/policy.rs index 859e69e8..fcb00d17 100644 --- a/rama-http/src/layer/retry/policy.rs +++ b/rama-http/src/layer/retry/policy.rs @@ -191,7 +191,7 @@ macro_rules! impl_retry_policy_either { async fn retry( &self, ctx: Context, - req: http::Request, + req: rama_http_types::Request, result: Result, ) -> PolicyResult { match self { @@ -204,8 +204,8 @@ macro_rules! 
impl_retry_policy_either { fn clone_input( &self, ctx: &Context, - req: &http::Request, - ) -> Option<(Context, http::Request)> { + req: &rama_http_types::Request, + ) -> Option<(Context, rama_http_types::Request)> { match self { $( rama_core::combinators::$id::$param(policy) => policy.clone_input(ctx, req), diff --git a/rama-http/src/layer/trace/body.rs b/rama-http/src/layer/trace/body.rs index 1a991860..9793807c 100644 --- a/rama-http/src/layer/trace/body.rs +++ b/rama-http/src/layer/trace/body.rs @@ -96,7 +96,7 @@ where self.inner.is_end_stream() } - fn size_hint(&self) -> http_body::SizeHint { + fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint { self.inner.size_hint() } } diff --git a/rama-http/src/layer/trace/on_response.rs b/rama-http/src/layer/trace/on_response.rs index f8941e35..fc44faa1 100644 --- a/rama-http/src/layer/trace/on_response.rs +++ b/rama-http/src/layer/trace/on_response.rs @@ -171,7 +171,7 @@ fn status(res: &Response) -> Option { // For simplicity, we simply check that the content type starts with "application/grpc". let is_grpc = res .headers() - .get(http::header::CONTENT_TYPE) + .get(rama_http_types::header::CONTENT_TYPE) .is_some_and(|value| value.as_bytes().starts_with("application/grpc".as_bytes())); if is_grpc { diff --git a/rama-http/src/layer/traffic_writer/response.rs b/rama-http/src/layer/traffic_writer/response.rs index 76b3d0dd..18564677 100644 --- a/rama-http/src/layer/traffic_writer/response.rs +++ b/rama-http/src/layer/traffic_writer/response.rs @@ -287,7 +287,7 @@ where .map_err(|err| OpaqueError::from_boxed(err.into())) .context("printer prepare: collect response body")? 
.to_bytes(); - let resp: http::Response = + let resp: rama_http_types::Response = Response::from_parts(parts.clone(), Body::from(body_bytes.clone())); self.writer.write_response(resp).await; Response::from_parts(parts, Body::from(body_bytes)) diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index 4249486e..bbf45b65 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -28,7 +28,7 @@ //! //! let _ = service //! .get("http://www.example.com") -//! .typed_header(headers::UserAgent::from_static(UA)) +//! .typed_header(rama_http_types::headers::UserAgent::from_static(UA)) //! .send(Context::default()) //! .await //! .unwrap(); diff --git a/rama-http/src/layer/util/compression.rs b/rama-http/src/layer/util/compression.rs index a4697888..780f9d78 100644 --- a/rama-http/src/layer/util/compression.rs +++ b/rama-http/src/layer/util/compression.rs @@ -1,7 +1,5 @@ //! Types used by compression and decompression middleware. -use super::content_encoding::SupportedEncodings; -use crate::HeaderValue; use crate::dep::http_body::{Body, Frame}; use bytes::{Buf, Bytes, BytesMut}; use futures_lite::Stream; @@ -16,88 +14,6 @@ use std::{ use tokio::io::AsyncRead; use tokio_util::io::StreamReader; -#[derive(Debug, Clone, Copy)] -pub(crate) struct AcceptEncoding { - pub(crate) gzip: bool, - pub(crate) deflate: bool, - pub(crate) br: bool, - pub(crate) zstd: bool, -} - -impl AcceptEncoding { - #[allow(dead_code)] - pub(crate) fn to_header_value(self) -> Option { - let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { - (true, true, true, false) => "gzip,deflate,br", - (true, true, false, false) => "gzip,deflate", - (true, false, true, false) => "gzip,br", - (true, false, false, false) => "gzip", - (false, true, true, false) => "deflate,br", - (false, true, false, false) => "deflate", - (false, false, true, false) => "br", - (true, true, true, true) => "zstd,gzip,deflate,br", - (true, true, false, true) => "zstd,gzip,deflate", - 
(true, false, true, true) => "zstd,gzip,br", - (true, false, false, true) => "zstd,gzip", - (false, true, true, true) => "zstd,deflate,br", - (false, true, false, true) => "zstd,deflate", - (false, false, true, true) => "zstd,br", - (false, false, false, true) => "zstd", - (false, false, false, false) => return None, - }; - Some(HeaderValue::from_static(accept)) - } - - #[allow(dead_code)] - pub(crate) fn set_gzip(&mut self, enable: bool) { - self.gzip = enable; - } - - #[allow(dead_code)] - pub(crate) fn set_deflate(&mut self, enable: bool) { - self.deflate = enable; - } - - #[allow(dead_code)] - pub(crate) fn set_br(&mut self, enable: bool) { - self.br = enable; - } - - #[allow(dead_code)] - pub(crate) fn set_zstd(&mut self, enable: bool) { - self.zstd = enable; - } -} - -impl SupportedEncodings for AcceptEncoding { - fn gzip(&self) -> bool { - self.gzip - } - - fn deflate(&self) -> bool { - self.deflate - } - - fn br(&self) -> bool { - self.br - } - - fn zstd(&self) -> bool { - self.zstd - } -} - -impl Default for AcceptEncoding { - fn default() -> Self { - AcceptEncoding { - gzip: true, - deflate: true, - br: true, - zstd: true, - } - } -} - /// A `Body` that has been converted into an `AsyncRead`. pub(crate) type AsyncReadBody = StreamReader, ::Error>, ::Data>; @@ -329,7 +245,7 @@ where } #[inline] - fn size_hint(&self) -> http_body::SizeHint { + fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint { self.body.size_hint() } } diff --git a/rama-http/src/layer/util/mod.rs b/rama-http/src/layer/util/mod.rs index 1fdeba4e..b55fcaa1 100644 --- a/rama-http/src/layer/util/mod.rs +++ b/rama-http/src/layer/util/mod.rs @@ -2,5 +2,3 @@ #[cfg(feature = "compression")] pub(crate) mod compression; - -pub(crate) mod content_encoding; diff --git a/rama-http/src/matcher/mod.rs b/rama-http/src/matcher/mod.rs index b9d191ef..42b6fd1e 100644 --- a/rama-http/src/matcher/mod.rs +++ b/rama-http/src/matcher/mod.rs @@ -1,9 +1,9 @@ -//! 
[`service::Matcher`]s implementations to match on [`http::Request`]s. +//! [`service::Matcher`]s implementations to match on [`rama_http_types::Request`]s. //! //! See [`service::matcher` module] for more information. //! //! [`service::Matcher`]: rama_core::matcher::Matcher -//! [`http::Request`]: crate::Request +//! [`rama_http_types::Request`]: crate::Request //! [`service::matcher` module]: rama_core use crate::Request; use rama_core::{Context, context::Extensions, matcher::IteratorMatcherExt}; @@ -455,7 +455,10 @@ impl HttpMatcher { } /// Create a [`HeaderMatcher`] matcher. - pub fn header(name: http::header::HeaderName, value: http::header::HeaderValue) -> Self { + pub fn header( + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, + ) -> Self { Self { kind: HttpMatcherKind::Header(HeaderMatcher::is(name, value)), negate: false, @@ -467,8 +470,8 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. pub fn and_header( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.and(Self::header(name, value)) } @@ -478,15 +481,15 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. pub fn or_header( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.or(Self::header(name, value)) } /// Create a [`HeaderMatcher`] matcher when the given header exists /// to match on the existence of a header. - pub fn header_exists(name: http::header::HeaderName) -> Self { + pub fn header_exists(name: rama_http_types::header::HeaderName) -> Self { Self { kind: HttpMatcherKind::Header(HeaderMatcher::exists(name)), negate: false, @@ -497,7 +500,7 @@ impl HttpMatcher { /// on top of the existing set of [`HttpMatcher`] matchers. 
/// /// See [`HeaderMatcher`] for more information. - pub fn and_header_exists(self, name: http::header::HeaderName) -> Self { + pub fn and_header_exists(self, name: rama_http_types::header::HeaderName) -> Self { self.and(Self::header_exists(name)) } @@ -505,14 +508,14 @@ impl HttpMatcher { /// as an alternative to the existing set of [`HttpMatcher`] matchers. /// /// See [`HeaderMatcher`] for more information. - pub fn or_header_exists(self, name: http::header::HeaderName) -> Self { + pub fn or_header_exists(self, name: rama_http_types::header::HeaderName) -> Self { self.or(Self::header_exists(name)) } /// Create a [`HeaderMatcher`] matcher to match on it containing the given value. pub fn header_contains( - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { Self { kind: HttpMatcherKind::Header(HeaderMatcher::contains(name, value)), @@ -526,8 +529,8 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. pub fn and_header_contains( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.and(Self::header_contains(name, value)) } @@ -538,8 +541,8 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. 
pub fn or_header_contains( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.or(Self::header_contains(name, value)) } diff --git a/rama-http/src/service/fs/mod.rs b/rama-http/src/service/fs/mod.rs index 8663f7a4..1a8dc83d 100644 --- a/rama-http/src/service/fs/mod.rs +++ b/rama-http/src/service/fs/mod.rs @@ -2,8 +2,8 @@ use bytes::Bytes; use futures_lite::Stream; -use http_body::{Body, Frame}; use pin_project_lite::pin_project; +use rama_http_types::dep::http_body::{Body, Frame}; use std::{ fmt, io, pin::Pin, diff --git a/rama-http/src/service/fs/serve_dir/future.rs b/rama-http/src/service/fs/serve_dir/future.rs index 142bddda..128c9715 100644 --- a/rama-http/src/service/fs/serve_dir/future.rs +++ b/rama-http/src/service/fs/serve_dir/future.rs @@ -3,11 +3,12 @@ use crate::{ Body, HeaderValue, Request, Response, StatusCode, dep::http_body_util::BodyExt, header::{self, ALLOW}, - layer::util::content_encoding::Encoding, service::fs::AsyncReadBody, }; use bytes::Bytes; use rama_core::{Context, Service, error::BoxError}; +use rama_http_types::dep::http_body; +use rama_http_types::headers::encoding::Encoding; use std::{convert::Infallible, io}; pub(super) async fn consume_open_file_result( @@ -25,7 +26,8 @@ where Ok(OpenFileOutput::Redirect { location }) => { let mut res = response_with_status(StatusCode::TEMPORARY_REDIRECT); - res.headers_mut().insert(http::header::LOCATION, location); + res.headers_mut() + .insert(rama_http_types::header::LOCATION, location); Ok(res) } @@ -123,7 +125,7 @@ fn build_response(output: FileOpened) -> Response { .maybe_encoding .filter(|encoding| *encoding != Encoding::Identity) { - builder = builder.header(header::CONTENT_ENCODING, encoding.into_header_value()); + builder = builder.header(header::CONTENT_ENCODING, HeaderValue::from(encoding)); } if let Some(last_modified) = output.last_modified { diff --git 
a/rama-http/src/service/fs/serve_dir/mod.rs b/rama-http/src/service/fs/serve_dir/mod.rs index 0db993aa..4981ca95 100644 --- a/rama-http/src/service/fs/serve_dir/mod.rs +++ b/rama-http/src/service/fs/serve_dir/mod.rs @@ -1,13 +1,11 @@ -use crate::dep::http_body::Body as HttpBody; -use crate::layer::{ - set_status::SetStatus, - util::content_encoding::{SupportedEncodings, encodings}, -}; +use crate::dep::http_body::{self, Body as HttpBody}; +use crate::layer::set_status::SetStatus; use crate::{Body, HeaderValue, Method, Request, Response, StatusCode, header}; use bytes::Bytes; use percent_encoding::percent_decode; use rama_core::error::BoxError; use rama_core::{Context, Service}; +use rama_http_types::headers::encoding::{SupportedEncodings, parse_accept_encoding_headers}; use std::{ convert::Infallible, path::{Component, Path, PathBuf}, @@ -536,7 +534,7 @@ impl ServeDir { .and_then(|value| value.to_str().ok()) .map(|s| s.to_owned()); - let negotiated_encodings: Vec<_> = encodings( + let negotiated_encodings: Vec<_> = parse_accept_encoding_headers( req.headers(), self.precompressed_variants.unwrap_or_default(), ) diff --git a/rama-http/src/service/fs/serve_dir/open_file.rs b/rama-http/src/service/fs/serve_dir/open_file.rs index 198703bc..6d92cf37 100644 --- a/rama-http/src/service/fs/serve_dir/open_file.rs +++ b/rama-http/src/service/fs/serve_dir/open_file.rs @@ -2,9 +2,9 @@ use super::{ ServeVariant, headers::{IfModifiedSince, IfUnmodifiedSince, LastModified}, }; -use crate::layer::util::content_encoding::{Encoding, QValue}; use crate::{HeaderValue, Method, Request, Uri, header}; use http_range_header::RangeUnsatisfiableError; +use rama_http_types::headers::{encoding::Encoding, specifier::QualityValue}; use std::{ ffi::OsStr, fs::Metadata, @@ -40,7 +40,7 @@ pub(super) async fn open_file( variant: ServeVariant, mut path_to_file: PathBuf, req: Request, - negotiated_encodings: Vec<(Encoding, QValue)>, + negotiated_encodings: Vec>, range_header: Option, buf_chunk_size: 
usize, ) -> io::Result { @@ -172,9 +172,10 @@ fn check_modified_headers( // to the corresponding file extension for the encoding. fn preferred_encoding( path: &mut PathBuf, - negotiated_encoding: &[(Encoding, QValue)], + negotiated_encoding: &[QualityValue], ) -> Option { - let preferred_encoding = Encoding::preferred_encoding(negotiated_encoding.iter().copied()); + let preferred_encoding = + Encoding::maybe_preferred_encoding(negotiated_encoding.iter().copied()); if let Some(file_extension) = preferred_encoding.and_then(|encoding| encoding.to_file_extension()) @@ -199,7 +200,7 @@ fn preferred_encoding( // file the uncompressed file is used as a fallback. async fn open_file_with_fallback( mut path: PathBuf, - mut negotiated_encoding: Vec<(Encoding, QValue)>, + mut negotiated_encoding: Vec>, ) -> io::Result<(File, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. @@ -211,8 +212,7 @@ async fn open_file_with_fallback( // to reset the path before the next iteration. path.set_extension(OsStr::new("")); // Remove the encoding from the negotiated_encodings since the file doesn't exist - negotiated_encoding - .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); + negotiated_encoding.retain(|qv| qv.value != encoding); continue; } (Err(err), _) => return Err(err), @@ -226,7 +226,7 @@ async fn open_file_with_fallback( // file the uncompressed file is used as a fallback. async fn file_metadata_with_fallback( mut path: PathBuf, - mut negotiated_encoding: Vec<(Encoding, QValue)>, + mut negotiated_encoding: Vec>, ) -> io::Result<(Metadata, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. @@ -238,8 +238,7 @@ async fn file_metadata_with_fallback( // to reset the path before the next iteration. 
path.set_extension(OsStr::new("")); // Remove the encoding from the negotiated_encodings since the file doesn't exist - negotiated_encoding - .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); + negotiated_encoding.retain(|qv| qv.value != encoding); continue; } (Err(err), _) => return Err(err), @@ -288,7 +287,7 @@ async fn is_dir(path_to_file: &Path) -> bool { } fn append_slash_on_path(uri: Uri) -> Uri { - let http::uri::Parts { + let rama_http_types::dep::http::uri::Parts { scheme, authority, path_and_query, diff --git a/rama-http/src/service/fs/serve_dir/tests.rs b/rama-http/src/service/fs/serve_dir/tests.rs index 9f313f31..d1e036c9 100644 --- a/rama-http/src/service/fs/serve_dir/tests.rs +++ b/rama-http/src/service/fs/serve_dir/tests.rs @@ -395,7 +395,7 @@ async fn redirect_to_trailing_slash_on_dir() { assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); - let location = &res.headers()[http::header::LOCATION]; + let location = &res.headers()[rama_http_types::header::LOCATION]; assert_eq!(location, "/src/"); } diff --git a/rama-http/src/service/fs/serve_file.rs b/rama-http/src/service/fs/serve_file.rs index 77220613..3c0c45b0 100644 --- a/rama-http/src/service/fs/serve_file.rs +++ b/rama-http/src/service/fs/serve_file.rs @@ -148,7 +148,7 @@ mod compression_tests { #[cfg(feature = "compression")] async fn precompressed_zstd() { use async_compression::tokio::bufread::ZstdDecoder; - use http_body_util::BodyExt; + use rama_http_types::dep::http_body_util::BodyExt; use tokio::io::AsyncReadExt; let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_zstd(); diff --git a/rama-http/src/service/web/endpoint/extract/authority.rs b/rama-http/src/service/web/endpoint/extract/authority.rs index 86d69224..2ff6b768 100644 --- a/rama-http/src/service/web/endpoint/extract/authority.rs +++ b/rama-http/src/service/web/endpoint/extract/authority.rs @@ -1,7 +1,7 @@ use super::FromRequestContextRefPair; -use crate::dep::http::request::Parts; 
use crate::utils::macros::define_http_rejection; use rama_core::Context; +use rama_http_types::dep::http::request::Parts; use rama_net::address; use rama_net::http::RequestContext; use rama_utils::macros::impl_deref; @@ -80,7 +80,7 @@ mod tests { test_authority_from_request( "/", "some-domain:123", - vec![(&http::header::HOST, "some-domain:123")], + vec![(&rama_http_types::header::HOST, "some-domain:123")], ) .await; } @@ -102,7 +102,7 @@ mod tests { "some-domain:456", vec![ (&X_FORWARDED_HOST, "some-domain:456"), - (&http::header::HOST, "some-domain:123"), + (&rama_http_types::header::HOST, "some-domain:123"), ], ) .await; diff --git a/rama-http/src/service/web/endpoint/extract/body/csv.rs b/rama-http/src/service/web/endpoint/extract/body/csv.rs index f37836fe..01353cf8 100644 --- a/rama-http/src/service/web/endpoint/extract/body/csv.rs +++ b/rama-http/src/service/web/endpoint/extract/body/csv.rs @@ -100,9 +100,12 @@ mod test { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/csv; charset=utf-8") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header( + rama_http_types::header::CONTENT_TYPE, + "text/csv; charset=utf-8", + ) .body("name,age,alive\nglen,42,\nadr,40,true\n".into()) .unwrap(); let resp = service.serve(Context::default(), req).await.unwrap(); @@ -122,9 +125,9 @@ mod test { let service = WebService::default() .post("/", |Csv(_): Csv>| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/plain") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header(rama_http_types::header::CONTENT_TYPE, "text/plain") .body(r#"{"name": "glen", "age": 42}"#.into()) .unwrap(); let resp = service.serve(Context::default(), req).await.unwrap(); @@ -143,9 +146,12 @@ mod test { let service = WebService::default() .post("/", 
|Csv(_): Csv>| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/csv; charset=utf-8") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header( + rama_http_types::header::CONTENT_TYPE, + "text/csv; charset=utf-8", + ) // the missing column last line should trigger an error .body("name,age,alive\nglen,42,\nadr,40\n".into()) .unwrap(); diff --git a/rama-http/src/service/web/endpoint/extract/body/json.rs b/rama-http/src/service/web/endpoint/extract/body/json.rs index c5862806..3c6959db 100644 --- a/rama-http/src/service/web/endpoint/extract/body/json.rs +++ b/rama-http/src/service/web/endpoint/extract/body/json.rs @@ -85,10 +85,10 @@ mod test { assert_eq!(body.alive, None); }); - let req = http::Request::builder() - .method(http::Method::POST) + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) .header( - http::header::CONTENT_TYPE, + rama_http_types::header::CONTENT_TYPE, "application/json; charset=utf-8", ) .body(r#"{"name": "glen", "age": 42}"#.into()) @@ -109,9 +109,9 @@ mod test { let service = WebService::default().post("/", |Json(_): Json| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/plain") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header(rama_http_types::header::CONTENT_TYPE, "text/plain") .body(r#"{"name": "glen", "age": 42}"#.into()) .unwrap(); let resp = service.serve(Context::default(), req).await.unwrap(); @@ -130,10 +130,10 @@ mod test { let service = WebService::default().post("/", |Json(_): Json| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) .header( - http::header::CONTENT_TYPE, + 
rama_http_types::header::CONTENT_TYPE, "application/json; charset=utf-8", ) .body(r#"deal with it, or not?!"#.into()) diff --git a/rama-http/src/service/web/endpoint/extract/host.rs b/rama-http/src/service/web/endpoint/extract/host.rs index f49c68ce..e65af74b 100644 --- a/rama-http/src/service/web/endpoint/extract/host.rs +++ b/rama-http/src/service/web/endpoint/extract/host.rs @@ -1,7 +1,7 @@ use super::FromRequestContextRefPair; -use crate::dep::http::request::Parts; use crate::utils::macros::define_http_rejection; use rama_core::Context; +use rama_http_types::dep::http::request::Parts; use rama_net::address; use rama_net::http::RequestContext; use rama_utils::macros::impl_deref; @@ -75,7 +75,7 @@ mod tests { test_host_from_request( "/", "some-domain", - vec![(&http::header::HOST, "some-domain:123")], + vec![(&rama_http_types::header::HOST, "some-domain:123")], ) .await; } @@ -97,7 +97,7 @@ mod tests { "some-domain", vec![ (&X_FORWARDED_HOST, "some-domain:456"), - (&http::header::HOST, "some-domain:123"), + (&rama_http_types::header::HOST, "some-domain:123"), ], ) .await; diff --git a/rama-http/src/service/web/endpoint/extract/typed_header.rs b/rama-http/src/service/web/endpoint/extract/typed_header.rs index 482d8080..230dde28 100644 --- a/rama-http/src/service/web/endpoint/extract/typed_header.rs +++ b/rama-http/src/service/web/endpoint/extract/typed_header.rs @@ -128,7 +128,7 @@ impl TypedHeaderRejectionReason { impl IntoResponse for TypedHeaderRejection { fn into_response(self) -> Response { - (http::StatusCode::BAD_REQUEST, self.to_string()).into_response() + (rama_http_types::StatusCode::BAD_REQUEST, self.to_string()).into_response() } } From 366913ddfaa87d0419df21a0601a0374aef0e90d Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sun, 23 Feb 2025 23:35:05 +0100 Subject: [PATCH 31/39] improve decompress req detect logic in rama-ua emulate service decompression will no longer happen in case the client req had a compatible encoding with the response 
content-encoding --- rama-ua/src/emulate/service.rs | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 7c71c773..30d6e406 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -9,9 +9,10 @@ use rama_http_types::{ compression::DecompressIfPossible, conn::Http1ClientContextParams, header::{ - ACCEPT, ACCEPT_ENCODING, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, - COOKIE, HOST, ORIGIN, REFERER, USER_AGENT, + ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, HOST, ORIGIN, + REFERER, USER_AGENT, }, + headers::encoding::{Encoding, parse_accept_encoding_headers}, proto::{ h1::{ Http1HeaderMap, @@ -218,7 +219,7 @@ where Some(HttpAgent::Preserve), ); - let mut decompression_marker = None; + let mut original_requested_encodings = None; if preserve_http { tracing::trace!( @@ -254,7 +255,11 @@ where .unwrap_or_default(), }; - let requested_compression = original_headers.get(ACCEPT_ENCODING).is_some(); + original_requested_encodings = Some( + parse_accept_encoding_headers(&original_headers, true) + .map(|qv| qv.value) + .collect::>(), + ); let output_headers = merge_http_headers( base_http_headers, @@ -264,10 +269,6 @@ where is_secure_request, ); - if !requested_compression && output_headers.contains_key(ACCEPT_ENCODING) { - decompression_marker = Some(DecompressIfPossible::default()); - } - tracing::trace!( ua_kind = %profile.ua_kind, ua_version = ?profile.ua_version, @@ -317,8 +318,17 @@ where .map_err(Into::into)? 
.into_response(); - if let Some(marker) = decompression_marker { - res.extensions_mut().insert(marker); + if let Some(original_requested_encodings) = original_requested_encodings { + if let Some(content_encoding) = + Encoding::maybe_from_content_encoding_header(res.headers(), true) + { + if !original_requested_encodings.contains(&content_encoding) { + // Only request decompression if the server used a content-encoding + // not listed in the original request's Accept-Encoding header + // or because the callee didn't set this header at all. + res.extensions_mut().insert(DecompressIfPossible::default()); + } + } } Ok(res) @@ -505,7 +515,7 @@ fn merge_http_headers( let base_header_name = base_name.header_name(); let original_value = original_headers.remove(base_header_name); match base_header_name { - &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE | &ACCEPT_ENCODING => { + &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE => { let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); } From ce67607323d8bf1e6f7f101fa695bd970745bba8 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 24 Feb 2025 11:49:11 +0100 Subject: [PATCH 32/39] add initial client hints support --- rama-http-types/src/headers/client_hints.rs | 262 ++++++++++++++++++ rama-http-types/src/headers/mod.rs | 3 + rama-http/src/layer/ua.rs | 11 + rama-ua/src/emulate/service.rs | 155 ++++++++++- rama-ua/src/ua/info.rs | 38 +++ rama-ua/src/ua/mod.rs | 3 + rama-ua/src/ua/parse.rs | 5 + .../http_user_agent_classifier.rs | 1 + 8 files changed, 470 insertions(+), 8 deletions(-) create mode 100644 rama-http-types/src/headers/client_hints.rs diff --git a/rama-http-types/src/headers/client_hints.rs b/rama-http-types/src/headers/client_hints.rs new file mode 100644 index 00000000..5dc88196 --- /dev/null +++ b/rama-http-types/src/headers/client_hints.rs @@ -0,0 +1,262 @@ +macro_rules! 
client_hint { + ( + #[doc = $ch_doc:literal] + pub enum ClientHint { + $( + #[doc = $doc:literal] + $name:ident($($str:literal),*), + )+ + } + ) => { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub enum ClientHint { + $( + #[doc = $doc] + $name, + )+ + } + + impl ClientHint { + #[doc = "Checks if the client hint is low entropy, meaning that it will be sent by default."] + pub fn is_low_entropy(&self) -> bool { + matches!(self, Self::SaveData | Self::Ua | Self::Mobile | Self::Platform) + } + + #[inline] + #[doc = "Attempts to convert a `HeaderName` to a `ClientHint`."] + pub fn match_header_name(name: &crate::HeaderName) -> Option { + name.try_into().ok() + } + + #[doc = "Returns the preferred string representation of the client hint."] + pub fn as_str(&self) -> &'static str { + match self { + $( + Self::$name => { + const VARIANTS: &'static [&'static str] = &[$($str,)+]; + VARIANTS[0] + }, + )+ + } + } + } + + rama_utils::macros::error::static_str_error! { + /// Client Hint Parsing Error + pub struct ClientHintParsingError; + } + + impl TryFrom<&str> for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: &str) -> Result { + rama_utils::macros::match_ignore_ascii_case_str! 
{ + match (name) { + $( + $($str)|+ => Ok(Self::$name), + )+ + _ => Err(ClientHintParsingError), + } + } + } + } + + impl TryFrom for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: String) -> Result { + Self::try_from(name.as_str()) + } + } + + impl TryFrom<$crate::HeaderName> for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: $crate::HeaderName) -> Result { + Self::try_from(name.as_str()) + } + } + + impl TryFrom<&$crate::HeaderName> for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: &$crate::HeaderName) -> Result { + Self::try_from(name.as_str()) + } + } + + impl std::str::FromStr for ClientHint { + type Err = ClientHintParsingError; + + #[inline] + fn from_str(s: &str) -> Result { + Self::try_from(s) + } + } + + impl std::fmt::Display for ClientHint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } + } + + impl serde::Serialize for ClientHint { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } + } + + impl<'de> serde::Deserialize<'de> for ClientHint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + let s = >::deserialize(deserializer)?; + Self::try_from(s.as_ref()).map_err(D::Error::custom) + } + } + }; +} + +// NOTE: we are open to contributions to this module, +// e.g. in case you wish typed headers for each or some of these client hint headers, +// we gladly mentor and guide you in the process. + +client_hint! { + #[doc = "Client Hints are a set of HTTP Headers and a JavaScript API that allow web browsers to send detailed information about the client device and browser to web servers. 
They are designed to be a successor to User-Agent, and provide a standardized way for web servers to optimize content for the client without relying on unreliable user-agent string-based detection or browser fingerprinting techniques."] + pub enum ClientHint { + /// Sec-CH-UA represents a user agent's branding and version. + Ua("Sec-CH-UA"), + /// Sec-CH-UA-Full-Version represents the user agent's full version. + FullVersion("Sec-CH-UA-Full-Version"), + /// Sec-CH-UA-Full-Version-List represents the full version for each brand in its brands list. + FullVersionList("Sec-CH-UA-Full-Version-List"), + /// Sec-CH-UA-Platform represents the platform on which a given user agent is executing. + Platform("Sec-CH-UA-Platform"), + /// Sec-CH-UA-Platform-Version represents the platform version on which a given user agent is executing. + PlatformVersion("Sec-CH-UA-Platform-Version"), + /// Sec-CH-UA-Arch represents the architecture of the platform on which a given user agent is executing. + Arch("Sec-CH-UA-Arch"), + /// Sec-CH-UA-Bitness represents the bitness of the architecture of the platform on which a given user agent is executing. + Bitness("Sec-CH-UA-Bitness"), + /// Sec-CH-UA-WoW64 is used to detect whether or not a user agent binary is running in 32-bit mode on 64-bit Windows. + Wow64("Sec-CH-UA-WoW64"), + /// Sec-CH-UA-Model represents the device on which a given user agent is executing. + Model("Sec-CH-UA-Model"), + /// Sec-CH-UA-Mobile is used to detect whether or not a user agent prefers a «mobile» user experience. + Mobile("Sec-CH-UA-Mobile"), + /// Sec-CH-UA-Form-Factors represents the form-factors of a device, historically represented as a token in the User-Agent string. + FormFactor("Sec-CH-UA-Form-Factors"), + /// Sec-CH-Lang (or Lang) represents the user's language preference. + Lang("Sec-CH-Lang", "Lang"), + /// Sec-CH-Save-Data (or Save-Data) represents the user agent's preference for reduced data usage. 
+ SaveData("Sec-CH-Save-Data", "Save-Data"), + /// Sec-CH-Width gives a server the layout width of the image. + Width("Sec-CH-Width"), + /// Sec-CH-Viewport-Width (or Viewport-Width) is the width of the user's viewport in CSS pixels. + ViewportWidth("Sec-CH-Viewport-Width", "Viewport-Width"), + /// Sec-CH-Viewport-Height represents the user-agent's current viewport height. + ViewportHeight("Sec-CH-Viewport-Height"), + /// Sec-CH-DPR (or DPR) reports the ratio of physical pixels to CSS pixels of the user's screen. + Dpr("Sec-CH-DPR", "DPR"), + /// Sec-CH-Device-Memory (or Device-Memory) reveals the approximate amount of memory the current device has in GiB. Because this information could be used to fingerprint users, the value of Device-Memory is intentionally coarse. Valid values are 0.25, 0.5, 1, 2, 4, and 8. + DeviceMemory("Sec-CH-Device-Memory", "Device-Memory"), + /// Sec-CH-RTT (or RTT) provides the approximate Round Trip Time, in milliseconds, on the application layer. The RTT hint, unlike transport layer RTT, includes server processing time. The value of RTT is rounded to the nearest 25 milliseconds to prevent fingerprinting. + Rtt("Sec-CH-RTT", "RTT"), + /// Sec-CH-Downlink (or Downlink) expressed in megabits per second (Mbps), reveals the approximate downstream speed of the user's connection. The value is rounded to the nearest multiple of 25 kilobits per second. Because again, fingerprinting. + Downlink("Sec-CH-Downlink", "Downlink"), + /// Sec-CH-ECT (or ECT) stands for Effective Connection Type. Its value is one of an enumerated list of connection types, each of which describes a connection within specified ranges of both RTT and Downlink values. Valid values for ECT are 4g, 3g, 2g, and slow-2g. + Ect("Sec-CH-ECT", "ECT"), + /// Sec-CH-Prefers-Color-Scheme represents the user's preferred color scheme. 
+ PrefersColorScheme("Sec-CH-Prefers-Color-Scheme"), + /// Sec-CH-Prefers-Reduced-Motion is used to detect if the user has requested the system minimize the amount of animation or motion it uses. + PrefersReducedMotion("Sec-CH-Prefers-Reduced-Motion"), + /// Sec-CH-Prefers-Reduced-Transparency is used to detect if the user has requested the system minimize the amount of transparent or translucent layer effects it uses. + PrefersReducedTransparency("Sec-CH-Prefers-Reduced-Transparency"), + /// Sec-CH-Prefers-Contrast is used to detect if the user has requested that the web content is presented with a higher (or lower) contrast. + PrefersContrast("Sec-CH-Prefers-Contrast"), + /// Sec-CH-Forced-Colors is used to detect if the user agent has enabled a forced colors mode where it enforces a user-chosen limited color palette on the page. + ForcedColors("Sec-CH-Forced-Colors"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_hint_ua_from_str() { + let hint = ClientHint::try_from("Sec-CH-UA").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_ua_from_str_lowercase() { + let hint = ClientHint::try_from("sec-ch-ua").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_ua_from_str_uppercase() { + let hint = ClientHint::try_from("SEC-CH-UA").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_ua_from_str_mixedcase() { + let hint = ClientHint::try_from("Sec-CH-UA").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_low_entropy() { + let hints = [ + "Sec-CH-UA", + "Sec-CH-UA-Mobile", + "Sec-CH-UA-Platform", + "Save-Data", + "Sec-CH-Save-Data", + ]; + + for hint in hints { + let hint = ClientHint::try_from(hint).expect(hint); + assert!(hint.is_low_entropy()); + } + } + + #[test] + fn test_client_hint_high_entropy() { + let hints = [ + "Sec-CH-UA-Full-Version", + "Sec-CH-UA-Full-Version-List", + 
"Sec-CH-UA-Platform-Version", + "Sec-CH-UA-Arch", + "Sec-CH-UA-Bitness", + "Sec-CH-UA-WoW64", + "Sec-CH-UA-Model", + "Sec-CH-UA-Form-Factors", + "Sec-CH-Width", + "Sec-CH-Viewport-Width", + "Sec-CH-Viewport-Height", + "Sec-CH-DPR", + "Sec-CH-Device-Memory", + "Sec-CH-RTT", + "Sec-CH-Downlink", + "Sec-CH-ECT", + "Sec-CH-Prefers-Color-Scheme", + "Sec-CH-Prefers-Reduced-Motion", + "Sec-CH-Prefers-Reduced-Transparency", + "Sec-CH-Prefers-Contrast", + "Sec-CH-Forced-Colors", + ]; + + for hint in hints { + let hint = ClientHint::try_from(hint).expect(hint); + assert!(!hint.is_low_entropy()); + } + } +} diff --git a/rama-http-types/src/headers/mod.rs b/rama-http-types/src/headers/mod.rs index 5b78013d..f0633571 100644 --- a/rama-http-types/src/headers/mod.rs +++ b/rama-http-types/src/headers/mod.rs @@ -98,3 +98,6 @@ pub mod util; mod common; pub use common::Accept; + +mod client_hints; +pub use client_hints::ClientHint; diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index bbf45b65..93199abd 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -154,6 +154,9 @@ where if let Some(req_init) = overwrites.req_init { ua.set_request_initiator(req_init); } + if let Some(req_client_hints) = overwrites.req_client_hints { + ua.set_requested_client_hints(req_client_hints); + } } ctx.insert(ua); @@ -208,8 +211,10 @@ mod tests { use crate::layer::required_header::AddRequiredRequestHeadersLayer; use crate::service::client::HttpClientExt; use crate::{IntoResponse, Response, StatusCode, headers}; + use itertools::Itertools; use rama_core::Context; use rama_core::service::service_fn; + use rama_http_types::headers::ClientHint; use rama_ua::RequestInitiator; use std::convert::Infallible; @@ -307,6 +312,10 @@ mod tests { assert_eq!(ua_info.kind, UserAgentKind::Chromium); assert_eq!(ua_info.version, Some(124)); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); + assert_eq!( + ua.requested_client_hints().join(", "), + "Sec-CH-Downlink, Sec-CH-ECT" 
+ ); Ok(StatusCode::OK.into_response()) } @@ -321,6 +330,7 @@ mod tests { "x-proxy-ua", serde_html_form::to_string(&UserAgentOverwrites { ua: Some(UA.to_owned()), + req_client_hints: Some(vec![ClientHint::Downlink, ClientHint::Ect]), ..Default::default() }) .unwrap(), @@ -362,6 +372,7 @@ mod tests { tls: Some(TlsAgent::Boringssl), preserve_ua: Some(true), req_init: Some(RequestInitiator::Xhr), + req_client_hints: Some(vec![ClientHint::Downlink]), }) .unwrap(), ) diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 30d6e406..252b732b 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -12,7 +12,10 @@ use rama_http_types::{ ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, HOST, ORIGIN, REFERER, USER_AGENT, }, - headers::encoding::{Encoding, parse_accept_encoding_headers}, + headers::{ + ClientHint, + encoding::{Encoding, parse_accept_encoding_headers}, + }, proto::{ h1::{ Http1HeaderMap, @@ -219,6 +222,10 @@ where Some(HttpAgent::Preserve), ); + let requested_client_hints = ctx + .get::() + .map(|ua| ua.requested_client_hints().copied().collect::>()); + let mut original_requested_encodings = None; if preserve_http { @@ -267,6 +274,7 @@ where original_headers, preserve_ua_header, is_secure_request, + requested_client_hints.as_deref(), ); tracing::trace!( @@ -502,6 +510,7 @@ fn merge_http_headers( original_headers: HeaderMap, preserve_ua_header: bool, is_secure_request: bool, + requested_client_hints: Option<&[ClientHint]>, ) -> Http1HeaderMap { let mut original_headers = HeaderMapValueRemover::from(original_headers); @@ -510,6 +519,18 @@ fn merge_http_headers( let mut output_headers_ref = &mut output_headers_a; + let is_header_allowed = |header_name: &HeaderName| { + if let Some(hint) = ClientHint::match_header_name(header_name) { + is_secure_request + && (hint.is_low_entropy() + || requested_client_hints + .map(|hints| hints.contains(&hint)) + .unwrap_or_default()) + } else { + 
is_secure_request || !starts_with_ignore_ascii_case(header_name.as_str(), "sec-fetch") + } + }; + // put all "base" headers in correct order, and with proper name casing for (base_name, base_value) in base_http_headers.clone().into_iter() { let base_header_name = base_name.header_name(); @@ -535,11 +556,7 @@ fn merge_http_headers( _ => { if base_header_name == CUSTOM_HEADER_MARKER { output_headers_ref = &mut output_headers_b; - } else if starts_with_ignore_ascii_case(base_header_name.as_str(), "sec-fetch") { - if is_secure_request { - output_headers_ref.push((base_name, base_value)); - } - } else { + } else if is_header_allowed(base_header_name) { output_headers_ref.push((base_name, base_value)); } } @@ -549,14 +566,20 @@ fn merge_http_headers( // respect original header order of original headers where possible for header_name in original_http_header_order.into_iter().flatten() { if let Some(value) = original_headers.remove(header_name.header_name()) { - output_headers_a.push((header_name, value)); + if is_header_allowed(header_name.header_name()) { + output_headers_a.push((header_name, value)); + } } } + let original_headers_iter = original_headers + .into_iter() + .filter(|(header_name, _)| is_header_allowed(header_name.header_name())); + Http1HeaderMap::from_iter( output_headers_a .into_iter() - .chain(original_headers) // add all remaining original headers in any order within the right loc + .chain(original_headers_iter) // add all remaining original headers in any order within the right loc .chain(output_headers_b), ) } @@ -584,6 +607,7 @@ mod tests { original_headers: Vec<(&'static str, &'static str)>, preserve_ua_header: bool, is_secure_request: bool, + requested_client_hints: Option>, expected: Vec<(&'static str, &'static str)>, } @@ -595,6 +619,7 @@ mod tests { original_headers: vec![], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![], }, TestCase { @@ -607,6 +632,7 @@ mod tests { original_headers: 
vec![], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("Accept", "text/html"), ("Content-Type", "application/json"), @@ -619,6 +645,7 @@ mod tests { original_headers: vec![("accept", "text/html")], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![("accept", "text/html")], }, TestCase { @@ -628,6 +655,7 @@ mod tests { original_headers: vec![("content-type", "application/json")], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("user-agent", "python/3.10"), @@ -648,6 +676,7 @@ mod tests { ], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("content-type", "application/json"), @@ -668,6 +697,7 @@ mod tests { ], preserve_ua_header: true, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("content-type", "application/json"), @@ -689,6 +719,7 @@ mod tests { ], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("content-type", "application/json"), @@ -717,6 +748,7 @@ mod tests { ], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("cookie", "foo=bar"), @@ -749,6 +781,7 @@ mod tests { ], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("authorization", "Bearer 42"), @@ -783,6 +816,7 @@ mod tests { ], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("accept", "text/html"), ("authorization", "Bearer 42"), @@ -835,6 +869,7 @@ mod tests { ], preserve_ua_header: false, is_secure_request: false, + requested_client_hints: None, expected: vec![ ("Host", "www.example.com"), ( @@ -900,6 +935,7 @@ 
mod tests { ], preserve_ua_header: false, is_secure_request: true, + requested_client_hints: None, expected: vec![ ( "User-Agent", @@ -926,6 +962,102 @@ mod tests { ("Priority", "u=0, i"), ], }, + TestCase { + description: "realistic browser example over tls with requested client hints", + base_http_headers: vec![ + ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "en-US,en;q=0.9"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Referer", "https://www.google.com/"), + ("Upgrade-Insecure-Requests", "1"), + ("x-rama-custom-header-marker", "1"), + ("Cookie", "rama-ua-test=1"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("Sec-CH-Downlink", "100"), + ("Sec-CH-Ect", "4g"), + ("Sec-CH-RTT", "100"), + ("Sec-CH-UA-Arch", "arm"), + ("Sec-CH-UA-Bitness", "64"), + ("Sec-CH-UA-Full-Version", "120.0.0.0"), + ("Sec-CH-UA-Full-Version-List", "Chrome 120.0.0.0"), + ("Sec-CH-UA-Mobile", "?0"), + ("Sec-CH-UA-Platform", "macOS"), + ("Sec-CH-UA-Platform-Version", "10.15.7"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + original_http_header_order: Some(vec![ + "x-show-price", + "x-show-price-currency", + "accept-language", + "cookie", + "Sec-CH-UA-Model", + ]), + original_headers: vec![ + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("cookie", "session=on; foo=bar"), + ("x-requested-with", "XMLHttpRequest"), + ("sec-ch-ua-model", "Macintosh"), + ], + preserve_ua_header: false, + is_secure_request: true, + requested_client_hints: Some(vec![ + "Downlink", + "Ect", + "RTT", + "Sec-CH-UA-Arch", + 
"Sec-CH-UA-Bitness", + "Sec-CH-UA-Model", + ]), + expected: vec![ + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Upgrade-Insecure-Requests", "1"), + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("Sec-CH-UA-Model", "Macintosh"), + ("x-requested-with", "XMLHttpRequest"), + ("Cookie", "session=on; foo=bar"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("Sec-CH-Downlink", "100"), + ("Sec-CH-Ect", "4g"), + ("Sec-CH-RTT", "100"), + ("Sec-CH-UA-Arch", "arm"), + ("Sec-CH-UA-Bitness", "64"), + ("Sec-CH-UA-Mobile", "?0"), // not requested, but low entropy + ("Sec-CH-UA-Platform", "macOS"), // not requested, but low entropy + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + }, ]; for test_case in test_cases { @@ -955,6 +1087,12 @@ mod tests { ); let preserve_ua_header = test_case.preserve_ua_header; let is_secure_request = test_case.is_secure_request; + let requested_client_hints = test_case.requested_client_hints.map(|hints| { + hints + .into_iter() + .map(|hint| ClientHint::from_str(hint).unwrap()) + .collect::>() + }); let output_headers = merge_http_headers( &base_http_headers, @@ -962,6 +1100,7 @@ mod tests { original_headers, preserve_ua_header, is_secure_request, + requested_client_hints.as_deref(), ); let output_str = output_headers diff --git a/rama-ua/src/ua/info.rs b/rama-ua/src/ua/info.rs index c83b83c7..28d90168 100644 --- a/rama-ua/src/ua/info.rs +++ b/rama-ua/src/ua/info.rs @@ -1,5 +1,6 @@ use super::{RequestInitiator, parse_http_user_agent_header}; use rama_core::error::OpaqueError; +use 
rama_http_types::headers::ClientHint; use rama_utils::macros::match_ignore_ascii_case_str; use serde::{Deserialize, Deserializer, Serialize}; use std::{convert::Infallible, fmt, str::FromStr, sync::Arc}; @@ -15,6 +16,7 @@ pub struct UserAgent { pub(super) tls_agent_overwrite: Option, pub(super) preserve_ua_header: bool, pub(super) request_initiator: Option, + pub(super) requested_client_hints: Option>, } impl fmt::Display for UserAgent { @@ -130,6 +132,42 @@ impl UserAgent { self.request_initiator } + /// Define the requested (High-Entropy) Client Hints. + pub fn with_requested_client_hints(mut self, req_client_hints: Vec) -> Self { + self.requested_client_hints = Some(req_client_hints); + self + } + + /// Define the requested (High-Entropy) Client Hints. + pub fn set_requested_client_hints(&mut self, req_client_hints: Vec) -> &mut Self { + self.requested_client_hints = Some(req_client_hints); + self + } + + /// Append a requested (High-Entropy) Client Hint. + pub fn append_requested_client_hint(&mut self, hint: ClientHint) -> &mut Self { + self.requested_client_hints + .get_or_insert_default() + .push(hint); + self + } + + /// Extend the requested (High-Entropy) Client Hints. + pub fn extend_requested_client_hints( + &mut self, + hints: impl IntoIterator, + ) -> &mut Self { + self.requested_client_hints + .get_or_insert_default() + .extend(hints); + self + } + + /// returns the requested (High-Entropy) Client Hints. + pub fn requested_client_hints(&self) -> impl Iterator { + self.requested_client_hints.iter().flatten() + } + /// returns the `User-Agent` (header) value used by the [`UserAgent`]. 
pub fn header_str(&self) -> &str { &self.header diff --git a/rama-ua/src/ua/mod.rs b/rama-ua/src/ua/mod.rs index e869d044..c0ccb0ed 100644 --- a/rama-ua/src/ua/mod.rs +++ b/rama-ua/src/ua/mod.rs @@ -1,4 +1,5 @@ use rama_core::error::OpaqueError; +use rama_http_types::headers::ClientHint; use rama_utils::macros::match_ignore_ascii_case_str; use serde::{Deserialize, Deserializer, Serialize}; use std::{fmt, str::FromStr}; @@ -31,6 +32,8 @@ pub struct UserAgentOverwrites { pub tls: Option, /// Preserve the original [`UserAgent`] header of the http `Request`. pub preserve_ua: Option, + /// Requested (High-Entropy) Client Hints. + pub req_client_hints: Option>, /// Hint a specific request intiator for UA Emulation. A related /// or default initiator might be chosen in case the hinted one is not available. /// diff --git a/rama-ua/src/ua/parse.rs b/rama-ua/src/ua/parse.rs index 344b649b..c05abb03 100644 --- a/rama-ua/src/ua/parse.rs +++ b/rama-ua/src/ua/parse.rs @@ -36,6 +36,7 @@ pub(crate) fn parse_http_user_agent_header(header: impl Into>) -> UserA tls_agent_overwrite: None, preserve_ua_header: false, request_initiator: None, + requested_client_hints: None, }; } } @@ -134,6 +135,7 @@ pub(crate) fn parse_http_user_agent_header(header: impl Into>) -> UserA tls_agent_overwrite: None, preserve_ua_header: false, request_initiator: None, + requested_client_hints: None, }, (None, _, Some(platform), _) => UserAgent { header, @@ -142,6 +144,7 @@ pub(crate) fn parse_http_user_agent_header(header: impl Into>) -> UserA tls_agent_overwrite: None, preserve_ua_header: false, request_initiator: None, + requested_client_hints: None, }, (None, _, None, Some(device)) => UserAgent { header, @@ -150,6 +153,7 @@ pub(crate) fn parse_http_user_agent_header(header: impl Into>) -> UserA tls_agent_overwrite: None, preserve_ua_header: false, request_initiator: None, + requested_client_hints: None, }, (None, _, None, None) => UserAgent { header, @@ -158,6 +162,7 @@ pub(crate) fn 
parse_http_user_agent_header(header: impl Into>) -> UserA tls_agent_overwrite: None, preserve_ua_header: false, request_initiator: None, + requested_client_hints: None, }, } } diff --git a/tests/integration/examples/example_tests/http_user_agent_classifier.rs b/tests/integration/examples/example_tests/http_user_agent_classifier.rs index 4a28659f..8e260e7e 100644 --- a/tests/integration/examples/example_tests/http_user_agent_classifier.rs +++ b/tests/integration/examples/example_tests/http_user_agent_classifier.rs @@ -69,6 +69,7 @@ async fn test_http_user_agent_classifier() { tls: Some(TlsAgent::Boringssl), preserve_ua: Some(false), req_init: None, + req_client_hints: None, }) .unwrap(), ) From 7f0fccbfa986f0111037cfc89d094eb2514c25b2 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Mon, 24 Feb 2025 14:14:58 +0100 Subject: [PATCH 33/39] ensure rama-fp requests _all_ Client-Hints that we are aware of --- Cargo.lock | 1 + rama-cli/Cargo.toml | 1 + rama-cli/src/cmd/fp/mod.rs | 32 ++++++--------------- rama-http-types/src/headers/client_hints.rs | 14 +++++++++ rama-http-types/src/headers/mod.rs | 4 ++- rama-http/src/headers/mod.rs | 12 ++++++++ 6 files changed, 40 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 25fe2070..48bd993c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2205,6 +2205,7 @@ dependencies = [ "bytes", "clap", "hex", + "itertools 0.14.0", "mimalloc", "rama", "serde", diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 3fa85d8b..73a9ff33 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -23,6 +23,7 @@ base64 = { workspace = true } bytes = { workspace = true } clap = { workspace = true } hex = { workspace = true } +itertools = { workspace = true } jemallocator = { workspace = true, optional = true } mimalloc = { workspace = true, optional = true } rama = { version = "0.2.0-alpha.7", path = "..", features = ["full"] } diff --git a/rama-cli/src/cmd/fp/mod.rs b/rama-cli/src/cmd/fp/mod.rs index 
4fd9cc51..f4c65b13 100644 --- a/rama-cli/src/cmd/fp/mod.rs +++ b/rama-cli/src/cmd/fp/mod.rs @@ -3,13 +3,17 @@ use base64::Engine; use base64::engine::general_purpose::STANDARD as ENGINE; use clap::Args; +use itertools::Itertools; use rama::{ cli::ForwardKind, combinators::Either7, error::{BoxError, OpaqueError}, http::{ HeaderName, HeaderValue, IntoResponse, - headers::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp}, + headers::{ + CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp, + client_hints::all_client_hint_header_name_strings, + }, layer::{ catch_panic::CatchPanicLayer, compression::CompressionLayer, forwarded::GetForwardedHeadersLayer, required_header::AddRequiredResponseHeadersLayer, @@ -199,28 +203,10 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { }; let address = format!("{}:{}", cfg.interface, cfg.port); - let ch_headers = [ - "Width", - "Downlink", - "Sec-CH-UA", - "Sec-CH-UA-Mobile", - "Sec-CH-UA-Full-Version", - "ETC", - "Save-Data", - "Sec-CH-UA-Platform", - "Sec-CH-Prefers-Reduced-Motion", - "Sec-CH-UA-Arch", - "Sec-CH-UA-Bitness", - "Sec-CH-UA-Model", - "Sec-CH-UA-Platform-Version", - "Sec-CH-UA-Prefers-Color-Scheme", - "Device-Memory", - "RTT", - "Sec-GPC", - ] - .join(", ") - .parse::() - .expect("parse header value"); + let ch_headers = all_client_hint_header_name_strings() + .join(", ") + .parse::() + .expect("parse header value"); graceful.spawn_task_fn(move |guard| async move { let inner_http_service = HijackLayer::new( diff --git a/rama-http-types/src/headers/client_hints.rs b/rama-http-types/src/headers/client_hints.rs index 5dc88196..472083a9 100644 --- a/rama-http-types/src/headers/client_hints.rs +++ b/rama-http-types/src/headers/client_hints.rs @@ -119,6 +119,20 @@ macro_rules! 
client_hint { Self::try_from(s.as_ref()).map_err(D::Error::custom) } } + + #[doc = "Returns an iterator over all client hint header name strings."] + pub fn all_client_hint_header_name_strings() -> impl Iterator { + [ + $( + $($str,)+ + )+ + ].into_iter() + } + + #[doc = "Returns an iterator over all client hint header names."] + pub fn all_client_hint_header_names() -> impl Iterator { + all_client_hint_header_name_strings().map($crate::HeaderName::from_static) + } }; } diff --git a/rama-http-types/src/headers/mod.rs b/rama-http-types/src/headers/mod.rs index f0633571..05bb13f3 100644 --- a/rama-http-types/src/headers/mod.rs +++ b/rama-http-types/src/headers/mod.rs @@ -100,4 +100,6 @@ mod common; pub use common::Accept; mod client_hints; -pub use client_hints::ClientHint; +pub use client_hints::{ + ClientHint, all_client_hint_header_name_strings, all_client_hint_header_names, +}; diff --git a/rama-http/src/headers/mod.rs b/rama-http/src/headers/mod.rs index ef8e3e0b..543145c9 100644 --- a/rama-http/src/headers/mod.rs +++ b/rama-http/src/headers/mod.rs @@ -85,6 +85,18 @@ pub use rama_http_types::headers::specifier::{Quality, QualityValue}; #[doc(inline)] pub use rama_http_types::headers::util; +#[doc(inline)] +pub use rama_http_types::headers::ClientHint; + +pub mod client_hints { + //! 
Http (UA) Client Hints + + #[doc(inline)] + pub use rama_http_types::headers::{ + all_client_hint_header_name_strings, all_client_hint_header_names, + }; +} + mod forwarded; #[doc(inline)] pub use forwarded::{ From 4dc6c121562eecef64bcdebb81aaf731a4eec5a4 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Tue, 25 Feb 2025 13:50:55 +0100 Subject: [PATCH 34/39] fix client hint header name iterator errors header names => lowercase --- rama-http-types/src/headers/client_hints.rs | 68 +++++++++++++-------- rama-http/src/layer/ua.rs | 2 +- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/rama-http-types/src/headers/client_hints.rs b/rama-http-types/src/headers/client_hints.rs index 472083a9..b85a56d4 100644 --- a/rama-http-types/src/headers/client_hints.rs +++ b/rama-http-types/src/headers/client_hints.rs @@ -144,57 +144,57 @@ client_hint! { #[doc = "Client Hints are a set of HTTP Headers and a JavaScript API that allow web browsers to send detailed information about the client device and browser to web servers. They are designed to be a successor to User-Agent, and provide a standardized way for web servers to optimize content for the client without relying on unreliable user-agent string-based detection or browser fingerprinting techniques."] pub enum ClientHint { /// Sec-CH-UA represents a user agent's branding and version. - Ua("Sec-CH-UA"), + Ua("sec-ch-ua"), /// Sec-CH-UA-Full-Version represents the user agent's full version. - FullVersion("Sec-CH-UA-Full-Version"), + FullVersion("sec-ch-ua-full-version"), /// Sec-CH-UA-Full-Version-List represents the full version for each brand in its brands list. - FullVersionList("Sec-CH-UA-Full-Version-List"), + FullVersionList("sec-ch-ua-full-version-list"), /// Sec-CH-UA-Platform represents the platform on which a given user agent is executing. 
- Platform("Sec-CH-UA-Platform"), + Platform("sec-ch-ua-platform"), /// Sec-CH-UA-Platform-Version represents the platform version on which a given user agent is executing. - PlatformVersion("Sec-CH-UA-Platform-Version"), + PlatformVersion("sec-ch-ua-platform-version"), /// Sec-CH-UA-Arch represents the architecture of the platform on which a given user agent is executing. - Arch("Sec-CH-UA-Arch"), + Arch("sec-ch-ua-arch"), /// Sec-CH-UA-Bitness represents the bitness of the architecture of the platform on which a given user agent is executing. - Bitness("Sec-CH-UA-Bitness"), + Bitness("sec-ch-ua-bitness"), /// Sec-CH-UA-WoW64 is used to detect whether or not a user agent binary is running in 32-bit mode on 64-bit Windows. - Wow64("Sec-CH-UA-WoW64"), + Wow64("sec-ch-ua-wow64"), /// Sec-CH-UA-Model represents the device on which a given user agent is executing. - Model("Sec-CH-UA-Model"), + Model("sec-ch-ua-model"), /// Sec-CH-UA-Mobile is used to detect whether or not a user agent prefers a «mobile» user experience. - Mobile("Sec-CH-UA-Mobile"), + Mobile("sec-ch-ua-mobile"), /// Sec-CH-UA-Form-Factors represents the form-factors of a device, historically represented as a token in the User-Agent string. - FormFactor("Sec-CH-UA-Form-Factors"), + FormFactor("sec-ch-ua-form-factors"), /// Sec-CH-Lang (or Lang) represents the user's language preference. - Lang("Sec-CH-Lang", "Lang"), + Lang("sec-ch-lang", "lang"), /// Sec-CH-Save-Data (or Save-Data) represents the user agent's preference for reduced data usage. - SaveData("Sec-CH-Save-Data", "Save-Data"), + SaveData("sec-ch-save-data", "save-data"), /// Sec-CH-Width gives a server the layout width of the image. - Width("Sec-CH-Width"), + Width("sec-ch-width"), /// Sec-CH-Viewport-Width (or Viewport-Width) is the width of the user's viewport in CSS pixels. 
- ViewportWidth("Sec-CH-Viewport-Width", "Viewport-Width"), + ViewportWidth("sec-ch-viewport-width", "viewport-width"), /// Sec-CH-Viewport-Height represents the user-agent's current viewport height. - ViewportHeight("Sec-CH-Viewport-Height"), + ViewportHeight("sec-ch-viewport-height"), /// Sec-CH-DPR (or DPR) reports the ratio of physical pixels to CSS pixels of the user's screen. - Dpr("Sec-CH-DPR", "DPR"), + Dpr("sec-ch-dpr", "dpr"), /// Sec-CH-Device-Memory (or Device-Memory) reveals the approximate amount of memory the current device has in GiB. Because this information could be used to fingerprint users, the value of Device-Memory is intentionally coarse. Valid values are 0.25, 0.5, 1, 2, 4, and 8. - DeviceMemory("Sec-CH-Device-Memory", "Device-Memory"), + DeviceMemory("sec-ch-device-memory", "device-memory"), /// Sec-CH-RTT (or RTT) provides the approximate Round Trip Time, in milliseconds, on the application layer. The RTT hint, unlike transport layer RTT, includes server processing time. The value of RTT is rounded to the nearest 25 milliseconds to prevent fingerprinting. - Rtt("Sec-CH-RTT", "RTT"), + Rtt("sec-ch-rtt", "rtt"), /// Sec-CH-Downlink (or Downlink) expressed in megabits per second (Mbps), reveals the approximate downstream speed of the user's connection. The value is rounded to the nearest multiple of 25 kilobits per second. Because again, fingerprinting. - Downlink("Sec-CH-Downlink", "Downlink"), + Downlink("sec-ch-downlink", "downlink"), /// Sec-CH-ECT (or ECT) stands for Effective Connection Type. Its value is one of an enumerated list of connection types, each of which describes a connection within specified ranges of both RTT and Downlink values. Valid values for ECT are 4g, 3g, 2g, and slow-2g. - Ect("Sec-CH-ECT", "ECT"), + Ect("sec-ch-ect", "ect"), /// Sec-CH-Prefers-Color-Scheme represents the user's preferred color scheme. 
- PrefersColorScheme("Sec-CH-Prefers-Color-Scheme"), + PrefersColorScheme("sec-ch-prefers-color-scheme"), /// Sec-CH-Prefers-Reduced-Motion is used to detect if the user has requested the system minimize the amount of animation or motion it uses. - PrefersReducedMotion("Sec-CH-Prefers-Reduced-Motion"), + PrefersReducedMotion("sec-ch-prefers-reduced-motion"), /// Sec-CH-Prefers-Reduced-Transparency is used to detect if the user has requested the system minimize the amount of transparent or translucent layer effects it uses. - PrefersReducedTransparency("Sec-CH-Prefers-Reduced-Transparency"), + PrefersReducedTransparency("sec-ch-prefers-reduced-transparency"), /// Sec-CH-Prefers-Contrast is used to detect if the user has requested that the web content is presented with a higher (or lower) contrast. - PrefersContrast("Sec-CH-Prefers-Contrast"), + PrefersContrast("sec-ch-prefers-contrast"), /// Sec-CH-Forced-Colors is used to detect if the user agent has enabled a forced colors mode where it enforces a user-chosen limited color palette on the page. 
- ForcedColors("Sec-CH-Forced-Colors"), + ForcedColors("sec-ch-forced-colors"), } } @@ -273,4 +273,20 @@ mod tests { assert!(!hint.is_low_entropy()); } } + + #[test] + fn test_all_client_hint_header_name_strings_contains_some_hints() { + let strings = all_client_hint_header_name_strings().collect::>(); + assert!(strings.contains(&"sec-ch-ua"), "{:?}", strings); + } + + #[test] + fn test_all_client_hint_header_names() { + let names = all_client_hint_header_names().collect::>(); + let strings = all_client_hint_header_name_strings().collect::>(); + assert_eq!(names.len(), strings.len()); + for (name, string) in names.iter().zip(strings.iter()) { + assert_eq!(name.as_str(), *string); + } + } } diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index 93199abd..84896ff7 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -314,7 +314,7 @@ mod tests { assert_eq!(ua.platform(), Some(PlatformKind::Windows)); assert_eq!( ua.requested_client_hints().join(", "), - "Sec-CH-Downlink, Sec-CH-ECT" + "sec-ch-downlink, sec-ch-ect" ); Ok(StatusCode::OK.into_response()) From 39fb9c3a42a415365bb244367882a58f48108ee9 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Tue, 25 Feb 2025 21:27:54 +0100 Subject: [PATCH 35/39] add fly config for rama-fp-pg --- rama-fp/infra/deployments/pg/fly.toml | 69 +++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 rama-fp/infra/deployments/pg/fly.toml diff --git a/rama-fp/infra/deployments/pg/fly.toml b/rama-fp/infra/deployments/pg/fly.toml new file mode 100644 index 00000000..a10c31cc --- /dev/null +++ b/rama-fp/infra/deployments/pg/fly.toml @@ -0,0 +1,69 @@ +# fly.toml app configuration file generated for rama-fp-pg on 2025-02-25T21:24:55+01:00 +# +# See https://fly.io/docs/reference/configuration/ for information about how to use this file. 
+# + +app = 'rama-fp-pg' +primary_region = 'lhr' + +[env] + PRIMARY_REGION = 'lhr' + +[[mounts]] + source = 'pg_data' + destination = '/data' + +[[services]] + protocol = 'tcp' + internal_port = 5432 + auto_start_machines = false + + [[services.ports]] + port = 5432 + handlers = ['pg_tls'] + + [services.concurrency] + type = 'connections' + hard_limit = 1000 + soft_limit = 1000 + +[[services]] + protocol = 'tcp' + internal_port = 5433 + auto_start_machines = false + + [[services.ports]] + port = 5433 + handlers = ['pg_tls'] + + [services.concurrency] + type = 'connections' + hard_limit = 1000 + soft_limit = 1000 + +[checks] + [checks.pg] + port = 5500 + type = 'http' + interval = '15s' + timeout = '10s' + path = '/flycheck/pg' + + [checks.role] + port = 5500 + type = 'http' + interval = '15s' + timeout = '10s' + path = '/flycheck/role' + + [checks.vm] + port = 5500 + type = 'http' + interval = '15s' + timeout = '10s' + path = '/flycheck/vm' + +[[metrics]] + port = 9187 + path = '/metrics' + https = false From 252618b033a0eb643f97e39e9e5714b1b5e6100a Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Fri, 28 Feb 2025 11:22:34 +0100 Subject: [PATCH 36/39] integrate storage logic (for now dummy) into rama-fp and adapt python+js code to it --- rama-cli/assets/script.js | 22 --- rama-cli/src/cmd/fp/data.rs | 164 +++++++++++++++--- rama-cli/src/cmd/fp/endpoints.rs | 145 +++++++--------- rama-cli/src/cmd/fp/mod.rs | 78 ++++++++- rama-cli/src/cmd/fp/state.rs | 27 ++- rama-cli/src/cmd/fp/storage.rs | 117 +++++++++++++ rama-fp/browserstack/main.py | 7 + rama-http-types/src/proto/h1/headers/map.rs | 6 + .../src/proto/h1/headers/original.rs | 4 + .../src/layer/auth/require_authorization.rs | 7 +- rama-ua/src/emulate/service.rs | 51 ++++-- rama-ua/src/profile/db.rs | 4 +- rama-ua/src/profile/http.rs | 12 +- 13 files changed, 481 insertions(+), 163 deletions(-) create mode 100644 rama-cli/src/cmd/fp/storage.rs diff --git a/rama-cli/assets/script.js 
b/rama-cli/assets/script.js index 1bcc1ef2..80b3f747 100644 --- a/rama-cli/assets/script.js +++ b/rama-cli/assets/script.js @@ -22,20 +22,6 @@ async function fetchWithBackoff(url, options) { throw new Error('Max retries exceeded'); } -// Function to make a GET request -async function makeGetRequest(url) { - const headers = { - 'x-RAMA-custom-header-marker': `rama-fp${Date.now()}`, - }; - - const options = { - method: 'GET', - headers - }; - - return fetchWithBackoff(url, options); -} - // Function to make a POST request async function makePostRequest(url, number) { const headers = { @@ -79,18 +65,10 @@ function makeRequestWithXHR(url, method, number) { // Main function to execute the requests async function main() { try { - // Fetch GET request - const response1 = await makeGetRequest('/api/fetch/number'); - const { number } = await response1.json(); - // Fetch POST request const response2 = await makePostRequest(`/api/fetch/number/${number}`, number); const result = await response2.json(); - // XMLHttpRequest GET request - const response3 = await makeRequestWithXHR('/api/xml/number', 'GET'); - const { number: number2 } = JSON.parse(response3); - // XMLHttpRequest POST request const response4 = await makeRequestWithXHR(`/api/xml/number/${number2}`, 'POST', number2); const result2 = JSON.parse(response4); diff --git a/rama-cli/src/cmd/fp/data.rs b/rama-cli/src/cmd/fp/data.rs index 424db958..d528c915 100644 --- a/rama-cli/src/cmd/fp/data.rs +++ b/rama-cli/src/cmd/fp/data.rs @@ -1,9 +1,9 @@ -use super::State; +use super::{State, StorageAuthorized}; use rama::{ Context, - error::{BoxError, ErrorContext}, + error::{BoxError, ErrorContext, OpaqueError}, http::{ - HeaderMap, Request, + self, HeaderMap, HeaderName, Request, dep::http::{Extensions, request::Parts}, headers::Forwarded, proto::{h1::Http1HeaderMap, h2::PseudoHeaderOrder}, @@ -17,7 +17,7 @@ use rama::{ SecureTransport, client::{ClientHello, ClientHelloExtension}, }, - ua::UserAgent, + ua::{Http1Settings, 
Http2Settings, UserAgent}, }; use serde::Serialize; use std::{str::FromStr, sync::Arc}; @@ -210,8 +210,118 @@ pub(super) struct HttpInfo { pub(super) pseudo_headers: Option>, } -pub(super) fn get_http_info(headers: HeaderMap, ext: &mut Extensions) -> HttpInfo { - let headers: Vec<_> = Http1HeaderMap::new(headers, Some(ext)) +pub(super) async fn get_and_store_http_info( + ctx: &Context>, + headers: HeaderMap, + ext: &mut Extensions, + http_version: http::Version, + ua: String, + initiator: Initiator, +) -> Result { + let original_headers = Http1HeaderMap::new(headers, Some(ext)); + let pseudo_headers = ext.get::(); + + if ctx.contains::() { + if let Some(storage) = ctx.state().storage.as_ref() { + match http_version { + http::Version::HTTP_09 | http::Version::HTTP_10 | http::Version::HTTP_11 => { + match initiator { + Initiator::Navigator => { + storage + .store_h1_headers_navigate(ua, original_headers.clone()) + .await + .context("store h1 headers navigate")?; + } + Initiator::Fetch => { + storage + .store_h1_headers_fetch(ua, original_headers.clone()) + .await + .context("store h1 headers fetch")?; + } + Initiator::XMLHttpRequest => { + if let Some(header_name) = original_headers.get_original_name( + &HeaderName::from_static("x-rama-custom-header-marker"), + ) { + // Check if the header name is title-cased or not + let header_str = header_name.as_str(); + let title_case_headers = header_str.split('-').all(|part| { + part.chars().next().is_none_or(|c| c.is_ascii_uppercase()) + && part.chars().skip(1).all(|c| c.is_ascii_lowercase()) + }); + + tracing::debug!( + "Custom header marker found: {}, title-cased: {}", + header_str, + title_case_headers + ); + + storage + .store_h1_settings( + ua.clone(), + Http1Settings { title_case_headers }, + ) + .await + .context("store h1 settings")?; + } + + storage + .store_h1_headers_xhr(ua, original_headers.clone()) + .await + .context("store h1 headers xhr")?; + } + Initiator::Form => { + storage + .store_h1_headers_form(ua, 
original_headers.clone()) + .await + .context("store h1 headers form")?; + } + } + } + http::Version::HTTP_2 => { + if let Some(pseudo_headers) = pseudo_headers { + storage + .store_h2_settings( + ua.clone(), + Http2Settings { + http_pseudo_headers: Some(pseudo_headers.iter().collect()), + }, + ) + .await + .context("store h2 pseudo headers")?; + } + match initiator { + Initiator::Navigator => { + storage + .store_h2_headers_navigate(ua, original_headers.clone()) + .await + .context("store h2 headers navigate")?; + } + Initiator::Fetch => { + storage + .store_h2_headers_fetch(ua, original_headers.clone()) + .await + .context("store h2 headers fetch")?; + } + Initiator::XMLHttpRequest => { + storage + .store_h2_headers_xhr(ua, original_headers.clone()) + .await + .context("store h2 headers xhr")?; + } + Initiator::Form => { + storage + .store_h2_headers_form(ua, original_headers.clone()) + .await + .context("store h2 headers form")?; + } + } + } + _ => (), + } + } + } + + let headers: Vec<_> = original_headers .into_iter() .map(|(name, value)| { ( @@ -223,14 +333,13 @@ pub(super) fn get_http_info(headers: HeaderMap, ext: &mut Extensions) -> HttpInf }) .collect(); - let pseudo_headers: Option> = ext - .get::() - .map(|o| o.iter().map(|p| p.to_string()).collect()); + let pseudo_headers: Option> = + pseudo_headers.map(|o| o.iter().map(|p| p.to_string()).collect()); - HttpInfo { + Ok(HttpInfo { headers, pseudo_headers, - } + }) } #[derive(Debug, Clone, Serialize)] @@ -267,20 +376,31 @@ pub(super) enum TlsDisplayInfoExtensionData { Multi(Vec), } -pub(super) fn get_tls_display_info(ctx: &Context>) -> Option { - let hello: &ClientHello = ctx +pub(super) async fn get_tls_display_info_and_store( + ctx: &Context>, + ua: String, +) -> Result, OpaqueError> { + let hello: &ClientHello = match ctx .get::() - .and_then(|st| st.client_hello())?; + .and_then(|st| st.client_hello()) + { + Some(hello) => hello, + None => return Ok(None), + }; - let ja4 = Ja4::compute(ctx.extensions()) 
- .inspect_err(|err| tracing::error!(?err, "ja4 compute failure")) - .ok()?; + if ctx.contains::() { + if let Some(storage) = ctx.state().storage.as_ref() { + storage + .store_tls_client_hello(ua, hello.clone()) + .await + .context("store tls client hello")?; + } + } - let ja3 = Ja3::compute(ctx.extensions()) - .inspect_err(|err| tracing::error!(?err, "ja3 compute failure")) - .ok()?; + let ja4 = Ja4::compute(ctx.extensions()).context("ja4 compute")?; + let ja3 = Ja3::compute(ctx.extensions()).context("ja3 compute")?; - Some(TlsDisplayInfo { + Ok(Some(TlsDisplayInfo { ja4: Ja4DisplayInfo { full: format!("{ja4:?}"), hash: format!("{ja4}"), @@ -355,5 +475,5 @@ pub(super) fn get_tls_display_info(ctx: &Context>) -> Option>(), - }) + })) } diff --git a/rama-cli/src/cmd/fp/endpoints.rs b/rama-cli/src/cmd/fp/endpoints.rs index 8076bf2c..6c94282a 100644 --- a/rama-cli/src/cmd/fp/endpoints.rs +++ b/rama-cli/src/cmd/fp/endpoints.rs @@ -2,7 +2,8 @@ use super::{ State, data::{ DataSource, FetchMode, Initiator, RequestInfo, ResourceType, TlsDisplayInfo, UserAgentInfo, - get_http_info, get_ja4h_info, get_request_info, get_tls_display_info, get_user_agent_info, + get_and_store_http_info, get_ja4h_info, get_request_info, get_tls_display_info_and_store, + get_user_agent_info, }, }; use crate::cmd::fp::data::TlsDisplayInfoExtensionData; @@ -95,7 +96,18 @@ pub(super) async fn get_report( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); + let user_agent = user_agent_info.user_agent.clone(); + + let http_info = get_and_store_http_info( + &ctx, + parts.headers, + &mut parts.extensions, + parts.version, + user_agent.clone(), + Initiator::Navigator, + ) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; let head = r#""#.to_owned(); @@ -126,7 +138,10 @@ pub(super) async fn get_report( }); } - let tls_info = 
get_tls_display_info(&ctx); + let tls_info = get_tls_display_info_and_store(&ctx, user_agent) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + if let Some(tls_info) = tls_info { let mut tls_tables = tls_info.into(); tables.append(&mut tls_tables); @@ -175,45 +190,6 @@ pub(super) struct APINumberParams { number: usize, } -pub(super) async fn get_api_fetch_number( - mut ctx: Context>, - req: Request, -) -> Result, Response> { - let ja4h = get_ja4h_info(&req); - - let (mut parts, _) = req.into_parts(); - - let user_agent_info = get_user_agent_info(&ctx).await; - - let request_info = get_request_info( - FetchMode::SameOrigin, - ResourceType::Xhr, - Initiator::Fetch, - &mut ctx, - &parts, - ) - .await - .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - - let http_info = get_http_info(parts.headers, &mut parts.extensions); - - let tls_info = get_tls_display_info(&ctx); - - Ok(Json(json!({ - "number": ctx.state().counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel), - "fp": { - "user_agent_info": user_agent_info, - "request_info": request_info, - "tls_info": tls_info, - "http_info": json!({ - "headers": http_info.headers, - "pseudo_headers": http_info.pseudo_headers, - "ja4h": ja4h, - }), - } - }))) -} - pub(super) async fn post_api_fetch_number( Path(params): Path, mut ctx: Context>, @@ -225,6 +201,8 @@ pub(super) async fn post_api_fetch_number( let user_agent_info = get_user_agent_info(&ctx).await; + let user_agent = user_agent_info.user_agent.clone(); + let request_info = get_request_info( FetchMode::SameOrigin, ResourceType::Xhr, @@ -235,51 +213,23 @@ pub(super) async fn post_api_fetch_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); - - let tls_info = get_tls_display_info(&ctx); - - Ok(Json(json!({ - "number": params.number, - "fp": { - 
"user_agent_info": user_agent_info, - "request_info": request_info, - "tls_info": tls_info, - "http_info": json!({ - "headers": http_info.headers, - "pseudo_headers": http_info.pseudo_headers, - "ja4h": ja4h, - }), - } - }))) -} - -pub(super) async fn get_api_xml_http_request_number( - mut ctx: Context>, - req: Request, -) -> Result, Response> { - let ja4h = get_ja4h_info(&req); - - let (mut parts, _) = req.into_parts(); - - let user_agent_info = get_user_agent_info(&ctx).await; - - let request_info = get_request_info( - FetchMode::SameOrigin, - ResourceType::Xhr, + let http_info = get_and_store_http_info( + &ctx, + parts.headers, + &mut parts.extensions, + parts.version, + user_agent.clone(), Initiator::Fetch, - &mut ctx, - &parts, ) .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); - - let tls_info = get_tls_display_info(&ctx); + let tls_info = get_tls_display_info_and_store(&ctx, user_agent) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; Ok(Json(json!({ - "number": ctx.state().counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel), + "number": params.number, "fp": { "user_agent_info": user_agent_info, "request_info": request_info, @@ -304,6 +254,8 @@ pub(super) async fn post_api_xml_http_request_number( let user_agent_info = get_user_agent_info(&ctx).await; + let user_agent = user_agent_info.user_agent.clone(); + let request_info = get_request_info( FetchMode::SameOrigin, ResourceType::Xhr, @@ -314,9 +266,20 @@ pub(super) async fn post_api_xml_http_request_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); + let http_info = get_and_store_http_info( + &ctx, + parts.headers, + &mut parts.extensions, + parts.version, + user_agent.clone(), + Initiator::XMLHttpRequest, + ) + .await 
+ .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let tls_info = get_tls_display_info(&ctx); + let tls_info = get_tls_display_info_and_store(&ctx, user_agent) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; Ok(Json(json!({ "number": params.number, @@ -344,6 +307,8 @@ pub(super) async fn form(mut ctx: Context>, req: Request) -> Result>, req: Request) -> Result>, req: Request) -> Result Result<(), BoxError> { ) .layer(match_service!{ HttpMatcher::get("/report") => endpoints::get_report, - HttpMatcher::get("/api/fetch/number") => endpoints::get_api_fetch_number, HttpMatcher::post("/api/fetch/number/:number") => endpoints::post_api_fetch_number, - HttpMatcher::get("/api/xml/number") => endpoints::get_api_xml_http_request_number, HttpMatcher::post("/api/xml/number/:number") => endpoints::post_api_xml_http_request_number, HttpMatcher::method_get().or_method_post().and_path("/form") => endpoints::form, _ => Redirect::temporary("/consent"), @@ -236,6 +240,7 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { HeaderName::from_static("x-sponsored-by"), HeaderValue::from_static("fly.io"), ), + StorageAuthLayer, SetResponseHeaderLayer::if_not_present( HeaderName::from_static("accept-ch"), ch_headers.clone(), @@ -281,7 +286,10 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { }) ); - let tcp_listener = TcpListener::build_with_state(Arc::new(State::new(acme_data))) + let pg_url = std::env::var("DATABASE_URL").ok(); + let storage_auth = std::env::var("RAMA_FP_STORAGE_COOKIE").ok(); + + let tcp_listener = TcpListener::build_with_state(Arc::new(State::new(acme_data, pg_url.as_deref(), storage_auth.as_deref()).await.expect("create state"))) .bind(&address) .await .expect("bind TCP Listener"); @@ -351,3 +359,63 @@ impl FromStr for HttpVersion { }) } } + +#[derive(Debug, Clone, Default)] +struct StorageAuthLayer; + +impl Layer for StorageAuthLayer { 
+ type Service = StorageAuthService; + + fn layer(&self, inner: S) -> Self::Service { + StorageAuthService { inner } + } +} + +struct StorageAuthService { + inner: S, +} + +impl std::fmt::Debug for StorageAuthService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StorageAuthService") + .field("inner", &self.inner) + .finish() + } +} + +impl Service, Request> for StorageAuthService +where + Body: Send + 'static, + S: Service, Request>, +{ + type Response = S::Response; + type Error = S::Error; + + async fn serve( + &self, + mut ctx: Context>, + mut req: Request, + ) -> Result { + if let Some(cookie) = req.headers().typed_get::() { + let cookie = cookie + .iter() + .map(|(k, v)| { + if k.eq_ignore_ascii_case("rama-storage-auth") { + if Some(v) == ctx.state().storage_auth.as_deref() { + ctx.insert(StorageAuthorized); + } + "rama-storage-auth=xxx".to_owned() + } else { + format!("{k}={v}") + } + }) + .join("; "); + if !cookie.is_empty() { + req.headers_mut() + .insert(COOKIE, HeaderValue::from_str(&cookie).unwrap()); + } + } + + self.inner.serve(ctx, req).await + } +} diff --git a/rama-cli/src/cmd/fp/state.rs b/rama-cli/src/cmd/fp/state.rs index 62f08e82..898d543d 100644 --- a/rama-cli/src/cmd/fp/state.rs +++ b/rama-cli/src/cmd/fp/state.rs @@ -1,23 +1,36 @@ -use std::{collections::HashMap, sync::atomic::AtomicUsize}; +use std::collections::HashMap; -use super::data::DataSource; +use rama::error::{ErrorContext, OpaqueError}; + +use super::{data::DataSource, storage::Storage}; #[derive(Debug)] #[non_exhaustive] pub(super) struct State { pub(super) data_source: DataSource, - pub(super) counter: AtomicUsize, pub(super) acme: ACMEData, + pub(super) storage: Option, + pub(super) storage_auth: Option, } impl State { /// Create a new instance of [`State`]. 
- pub(super) fn new(acme: ACMEData) -> Self { - State { + pub(super) async fn new( + acme: ACMEData, + pg_url: Option<&str>, + storage_auth: Option<&str>, + ) -> Result { + let storage = match pg_url { + Some(pg_url) => Some(Storage::new(pg_url).await.context("create storage")?), + None => None, + }; + + Ok(State { data_source: DataSource::default(), - counter: AtomicUsize::new(0), acme, - } + storage, + storage_auth: storage_auth.map(|s| s.to_owned()), + }) } } diff --git a/rama-cli/src/cmd/fp/storage.rs b/rama-cli/src/cmd/fp/storage.rs new file mode 100644 index 00000000..6556aac0 --- /dev/null +++ b/rama-cli/src/cmd/fp/storage.rs @@ -0,0 +1,117 @@ +use rama::{ + error::OpaqueError, + http::proto::h1::Http1HeaderMap, + net::tls::client::ClientHello, + ua::{Http1Settings, Http2Settings}, +}; + +#[derive(Debug, Clone)] +pub(super) struct Storage; + +impl Storage { + pub(super) async fn new(pg_url: &str) -> Result { + tracing::info!("create new storage with PG URL: {}", pg_url); + Ok(Self) + } +} + +impl Storage { + pub(super) async fn store_h1_settings( + &self, + ua: String, + settings: Http1Settings, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 settings for UA '{ua}': {settings:?}"); + Ok(()) + } + + pub(super) async fn store_h1_headers_navigate( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h1_headers_fetch( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h1_headers_xhr( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h1_headers_form( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + 
tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h2_settings( + &self, + ua: String, + settings: Http2Settings, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 settings for UA '{ua}': {settings:?}"); + Ok(()) + } + + pub(super) async fn store_h2_headers_navigate( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h2_headers_fetch( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h2_headers_xhr( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_h2_headers_form( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + Ok(()) + } + + pub(super) async fn store_tls_client_hello( + &self, + ua: String, + tls_client_hello: ClientHello, + ) -> Result<(), OpaqueError> { + tracing::info!("store tls client hello for UA '{ua}': {tls_client_hello:?}"); + Ok(()) + } +} diff --git a/rama-fp/browserstack/main.py b/rama-fp/browserstack/main.py index 1a0dacd5..ebd39f45 100644 --- a/rama-fp/browserstack/main.py +++ b/rama-fp/browserstack/main.py @@ -70,6 +70,8 @@ def env(key): BROWSERSTACK_ACCESS_KEY = env("BROWSERSTACK_ACCESS_KEY") URL = os.environ.get("URL") or "https://hub.browserstack.com/wd/hub" +RAMA_FP_STORAGE_COOKIE = env("RAMA_FP_STORAGE_COOKIE") + def get_browser_option(browser): switcher = { @@ -154,6 +156,11 @@ def run_session(cap): options=options, ) + driver.set_cookie({ + 'name': 'rama-storage-auth', + 'value': RAMA_FP_STORAGE_COOKIE, + }) + driver.get(entrypoint) print("ua", 
driver.execute_script("return navigator.userAgent;")) print("loc", driver.execute_script("return document.location.href;")) diff --git a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs index d4789804..e316e6e1 100644 --- a/rama-http-types/src/proto/h1/headers/map.rs +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -61,6 +61,12 @@ impl Http1HeaderMap { self.headers.get(key) } + pub fn get_original_name(&self, key: &HeaderName) -> Option<&Http1HeaderName> { + self.original_headers + .iter() + .find(|header| header.header_name() == key) + } + #[inline] pub fn contains_key(&self, key: impl AsHeaderName) -> bool { self.headers.contains_key(key) diff --git a/rama-http-types/src/proto/h1/headers/original.rs b/rama-http-types/src/proto/h1/headers/original.rs index 91cb7bb1..741d160e 100644 --- a/rama-http-types/src/proto/h1/headers/original.rs +++ b/rama-http-types/src/proto/h1/headers/original.rs @@ -35,6 +35,10 @@ impl OriginalHttp1Headers { pub fn is_empty(&self) -> bool { self.ordered_headers.is_empty() } + + pub fn iter(&self) -> impl Iterator { + self.ordered_headers.iter() + } } impl OriginalHttp1Headers { diff --git a/rama-http/src/layer/auth/require_authorization.rs b/rama-http/src/layer/auth/require_authorization.rs index d7eb1dfd..5110d2f1 100644 --- a/rama-http/src/layer/auth/require_authorization.rs +++ b/rama-http/src/layer/auth/require_authorization.rs @@ -64,7 +64,6 @@ use rama_core::Context; use std::{fmt, marker::PhantomData, sync::Arc}; use rama_net::user::UserId; -use sealed::AuthorizerSeal; const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; @@ -283,7 +282,7 @@ mod sealed { use super::*; /// Private trait that contains the actual authorization logic - pub(super) trait AuthorizerSeal: Send + Sync + 'static { + pub trait AuthorizerSeal: Send + Sync + 'static { /// Check if the given header value is valid for this authorizer. 
fn is_valid(&self, header_value: &HeaderValue) -> bool; @@ -293,7 +292,7 @@ mod sealed { impl AuthorizerSeal for Basic { fn is_valid(&self, header_value: &HeaderValue) -> bool { - header_value == &self.header_value + header_value == self.header_value } fn www_authenticate_header() -> Option { @@ -303,7 +302,7 @@ mod sealed { impl AuthorizerSeal for Bearer { fn is_valid(&self, header_value: &HeaderValue) -> bool { - header_value == &self.header_value + header_value == self.header_value } fn www_authenticate_header() -> Option { diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs index 252b732b..2895f144 100644 --- a/rama-ua/src/emulate/service.rs +++ b/rama-ua/src/emulate/service.rs @@ -357,7 +357,7 @@ fn emulate_http_settings( "UA emulation add http1-specific settings", ); ctx.insert(Http1ClientContextParams { - title_header_case: profile.http.h1.title_case_headers, + title_header_case: profile.http.h1.settings.title_case_headers, }); } Version::HTTP_2 => { @@ -368,7 +368,13 @@ fn emulate_http_settings( "UA emulation add h2-specific settings", ); req.extensions_mut().insert(PseudoHeaderOrder::from_iter( - profile.http.h2.http_pseudo_headers.iter(), + profile + .http + .h2 + .settings + .http_pseudo_headers + .iter() + .flatten(), )); } Version::HTTP_3 => tracing::debug!( @@ -536,11 +542,12 @@ fn merge_http_headers( let base_header_name = base_name.header_name(); let original_value = original_headers.remove(base_header_name); match base_header_name { - &ACCEPT | &ACCEPT_LANGUAGE | &CONTENT_TYPE => { + &ACCEPT | &ACCEPT_LANGUAGE => { let value = original_value.unwrap_or(base_value); output_headers_ref.push((base_name, value)); } - &REFERER | &COOKIE | &AUTHORIZATION | &HOST | &ORIGIN | &CONTENT_LENGTH => { + &REFERER | &COOKIE | &AUTHORIZATION | &HOST | &ORIGIN | &CONTENT_LENGTH + | &CONTENT_TYPE => { if let Some(value) = original_value { output_headers_ref.push((base_name, value)); } @@ -596,7 +603,9 @@ mod tests { Body, BodyExtractExt, 
HeaderValue, header::ETAG, proto::h1::Http1HeaderName, }; - use crate::{Http1Profile, Http2Profile, HttpHeadersProfile, HttpProfile}; + use crate::{ + Http1Profile, Http1Settings, Http2Profile, Http2Settings, HttpHeadersProfile, HttpProfile, + }; #[test] fn test_merge_http_headers() { @@ -633,10 +642,20 @@ mod tests { preserve_ua_header: false, is_secure_request: false, requested_client_hints: None, - expected: vec![ + expected: vec![("Accept", "text/html")], + }, + TestCase { + description: "base headers only with content-type", + base_http_headers: vec![ ("Accept", "text/html"), ("Content-Type", "application/json"), ], + original_http_header_order: None, + original_headers: vec![("content-type", "text/xml")], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![("Accept", "text/html"), ("Content-Type", "text/xml")], }, TestCase { description: "original headers only", @@ -1229,7 +1248,7 @@ mod tests { xhr: None, form: None, }, - title_case_headers: false, + settings: Http1Settings::default(), }, h2: Http2Profile { headers: HttpHeadersProfile { @@ -1243,7 +1262,7 @@ mod tests { xhr: None, form: None, }, - http_pseudo_headers: vec![], + settings: Http2Settings::default(), }, }, #[cfg(feature = "tls")] @@ -1304,7 +1323,7 @@ mod tests { fetch: None, form: None, }, - title_case_headers: false, + settings: Http1Settings::default(), }, h2: Http2Profile { headers: HttpHeadersProfile { @@ -1313,7 +1332,7 @@ mod tests { xhr: None, form: None, }, - http_pseudo_headers: vec![], + settings: Http2Settings::default(), }, }, #[cfg(feature = "tls")] @@ -1377,7 +1396,7 @@ mod tests { None, )), }, - title_case_headers: false, + settings: Http1Settings::default(), }, h2: Http2Profile { headers: HttpHeadersProfile { @@ -1386,7 +1405,7 @@ mod tests { xhr: None, form: None, }, - http_pseudo_headers: vec![], + settings: Http2Settings::default(), }, }, #[cfg(feature = "tls")] @@ -1450,7 +1469,7 @@ mod tests { None, )), }, - 
title_case_headers: false, + settings: Http1Settings::default(), }, h2: Http2Profile { headers: HttpHeadersProfile { @@ -1459,7 +1478,7 @@ mod tests { xhr: None, form: None, }, - http_pseudo_headers: vec![], + settings: Http2Settings::default(), }, }, #[cfg(feature = "tls")] @@ -1532,7 +1551,7 @@ mod tests { None, )), }, - title_case_headers: false, + settings: Http1Settings::default(), }, h2: Http2Profile { headers: HttpHeadersProfile { @@ -1541,7 +1560,7 @@ mod tests { xhr: None, form: None, }, - http_pseudo_headers: vec![], + settings: Http2Settings::default(), }, }, #[cfg(feature = "tls")] diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs index 96aabb5c..95d20cf3 100644 --- a/rama-ua/src/profile/db.rs +++ b/rama-ua/src/profile/db.rs @@ -406,7 +406,7 @@ mod tests { xhr: None, form: None, }, - title_case_headers: false, + settings: crate::Http1Settings::default(), }, h2: crate::Http2Profile { headers: crate::HttpHeadersProfile { @@ -420,7 +420,7 @@ mod tests { xhr: None, form: None, }, - http_pseudo_headers: vec![], + settings: crate::Http2Settings::default(), }, }, #[cfg(feature = "tls")] diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs index 9a231184..6139cd4c 100644 --- a/rama-ua/src/profile/http.rs +++ b/rama-ua/src/profile/http.rs @@ -24,11 +24,21 @@ pub struct HttpHeadersProfile { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Http1Profile { pub headers: HttpHeadersProfile, + pub settings: Http1Settings, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct Http1Settings { pub title_case_headers: bool, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Http2Profile { pub headers: HttpHeadersProfile, - pub http_pseudo_headers: Vec, + pub settings: Http2Settings, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct Http2Settings { + pub http_pseudo_headers: Option>, } From b89860b908ed4f9214059ed4d40ce9b3babbff9f Mon Sep 17 00:00:00 2001 From: Glen De 
Cauwsemaecker Date: Fri, 28 Feb 2025 11:47:12 +0100 Subject: [PATCH 37/39] fix minor issues with rama-fp storage & script next steps: integrate PG db --- rama-cli/assets/script.js | 8 ++++++++ rama-cli/src/cmd/fp/storage.rs | 16 ++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/rama-cli/assets/script.js b/rama-cli/assets/script.js index 80b3f747..79c4c169 100644 --- a/rama-cli/assets/script.js +++ b/rama-cli/assets/script.js @@ -65,12 +65,20 @@ function makeRequestWithXHR(url, method, number) { // Main function to execute the requests async function main() { try { + // Generate random numbers for the requests + const number = Math.floor(Math.random() * 1000) + 1; + const number2 = Math.floor(Math.random() * 1000) + 1; + + console.log('Generated random numbers:', number, number2); + // Fetch POST request const response2 = await makePostRequest(`/api/fetch/number/${number}`, number); + console.log('Fetch POST request response:', response2); const result = await response2.json(); // XMLHttpRequest POST request const response4 = await makeRequestWithXHR(`/api/xml/number/${number2}`, 'POST', number2); + console.log('XMLHttpRequest POST request response:', response4); const result2 = JSON.parse(response4); console.log('Requests completed successfully'); diff --git a/rama-cli/src/cmd/fp/storage.rs b/rama-cli/src/cmd/fp/storage.rs index 6556aac0..65645fef 100644 --- a/rama-cli/src/cmd/fp/storage.rs +++ b/rama-cli/src/cmd/fp/storage.rs @@ -30,7 +30,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h1 navigateheaders for UA '{ua}': {headers:?}"); Ok(()) } @@ -39,7 +39,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h1 fetch headers for UA '{ua}': {headers:?}"); Ok(()) } @@ -48,7 +48,7 @@ impl 
Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h1 xhr headers for UA '{ua}': {headers:?}"); Ok(()) } @@ -57,7 +57,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h1 form headers for UA '{ua}': {headers:?}"); Ok(()) } @@ -75,7 +75,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h2 navigate headers for UA '{ua}': {headers:?}"); Ok(()) } @@ -84,7 +84,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h2 fetch headers for UA '{ua}': {headers:?}"); Ok(()) } @@ -93,7 +93,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h2 xhr headers for UA '{ua}': {headers:?}"); Ok(()) } @@ -102,7 +102,7 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 headers for UA '{ua}': {headers:?}"); + tracing::info!("store h2 form headers for UA '{ua}': {headers:?}"); Ok(()) } From ecb7d74279aca73a7bb83008391fab63bcd93842 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sun, 2 Mar 2025 10:55:42 +0100 Subject: [PATCH 38/39] integrate actual postgres storage into rama-fp --- Cargo.lock | 199 +++++++++++++ Cargo.toml | 2 + rama-cli/Cargo.toml | 2 + rama-cli/src/cmd/fp/mod.rs | 2 +- rama-cli/src/cmd/fp/state.rs | 2 +- rama-cli/src/cmd/fp/storage.rs | 117 -------- rama-cli/src/cmd/fp/storage/mod.rs | 267 ++++++++++++++++++ rama-cli/src/cmd/fp/storage/postgres.rs | 127 +++++++++ 
rama-tls/src/boring/client/connector.rs | 161 +++-------- rama-tls/src/boring/client/mod.rs | 10 +- rama-tls/src/boring/client/tls_stream.rs | 85 ++++++ rama-tls/src/boring/client/tls_stream_auto.rs | 132 +++++++++ 12 files changed, 863 insertions(+), 243 deletions(-) delete mode 100644 rama-cli/src/cmd/fp/storage.rs create mode 100644 rama-cli/src/cmd/fp/storage/mod.rs create mode 100644 rama-cli/src/cmd/fp/storage/postgres.rs create mode 100644 rama-tls/src/boring/client/tls_stream.rs create mode 100644 rama-tls/src/boring/client/tls_stream_auto.rs diff --git a/Cargo.lock b/Cargo.lock index baec2b6b..053c6c9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -637,6 +637,40 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" +[[package]] +name = "deadpool" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ed5957ff93768adf7a65ab167a17835c3d2c3c50d084fe305174c112f468e2f" +dependencies = [ + "deadpool-runtime", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-postgres" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d697d376cbfa018c23eb4caab1fd1883dd9c906a8c034e8d9a3cb06a7e0bef9" +dependencies = [ + "async-trait", + "deadpool", + "getrandom 0.2.15", + "tokio", + "tokio-postgres", + "tracing", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] + [[package]] name = "deranged" version = "0.3.11" @@ -685,6 +719,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -808,6 +843,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "fallible-iterator" +version 
= "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fastrand" version = "2.3.0" @@ -1162,6 +1203,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "hex" version = "0.4.3" @@ -1211,6 +1258,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.11" @@ -1698,6 +1754,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "md5" version = "0.7.0" @@ -1851,6 +1917,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.36.7" @@ -2027,6 +2103,24 @@ version = "2.3.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.9" @@ -2071,6 +2165,37 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "postgres-protocol" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" +dependencies = [ + "base64 0.22.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.9.0", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", + "serde", + "serde_json", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -2204,6 +2329,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "deadpool-postgres", "hex", "itertools 0.14.0", "mimalloc", @@ -2213,6 +2339,7 @@ dependencies = [ "terminal-prompt", "tikv-jemallocator", "tokio", + "tokio-postgres", "tracing", "tracing-subscriber", ] @@ -2986,6 +3113,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.9" @@ -3032,6 +3165,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" @@ -3288,6 +3432,32 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-postgres" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.9.0", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.26.1" @@ -3503,6 +3673,12 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.17" @@ -3518,6 +3694,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -3637,6 +3819,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -3739,6 +3927,17 @@ dependencies = [ "rustix", ] +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index dcff4dc6..a9cb1d43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ http = "1" http-body = "1" http-body-util = "0.1" http-range-header = "0.4.0" +deadpool-postgres = "0.14.1" httpdate = "1.0" boring = "4.9.1" tokio-boring = "4.9.1" @@ -142,6 +143,7 @@ slab = "0.4.2" indexmap = "2" rand = "0.9.0" walkdir = "2.3.2" +tokio-postgres = "0.7.13" env_logger = "0.11.6" httparse = "1.8" smallvec = { version = "1.12", features = ["const_generics", "const_new"] } diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 73a9ff33..8b15486a 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -22,6 +22,7 @@ mimalloc = ["dep:mimalloc"] base64 = { workspace = true } bytes = { workspace = true } clap = { workspace = true } +deadpool-postgres = { workspace = true } hex = { workspace = true } itertools = { workspace = true } jemallocator = { workspace = true, optional = true } @@ -31,6 +32,7 @@ serde = { workspace = true } serde_json = { workspace = true } terminal-prompt = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tokio-postgres = { workspace = true, features = ["with-serde_json-1"] } tracing 
= { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/rama-cli/src/cmd/fp/mod.rs b/rama-cli/src/cmd/fp/mod.rs index 165bd78f..a052a267 100644 --- a/rama-cli/src/cmd/fp/mod.rs +++ b/rama-cli/src/cmd/fp/mod.rs @@ -289,7 +289,7 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { let pg_url = std::env::var("DATABASE_URL").ok(); let storage_auth = std::env::var("RAMA_FP_STORAGE_COOKIE").ok(); - let tcp_listener = TcpListener::build_with_state(Arc::new(State::new(acme_data, pg_url.as_deref(), storage_auth.as_deref()).await.expect("create state"))) + let tcp_listener = TcpListener::build_with_state(Arc::new(State::new(acme_data, pg_url, storage_auth.as_deref()).await.expect("create state"))) .bind(&address) .await .expect("bind TCP Listener"); diff --git a/rama-cli/src/cmd/fp/state.rs b/rama-cli/src/cmd/fp/state.rs index 898d543d..fdbe5ba2 100644 --- a/rama-cli/src/cmd/fp/state.rs +++ b/rama-cli/src/cmd/fp/state.rs @@ -17,7 +17,7 @@ impl State { /// Create a new instance of [`State`]. 
pub(super) async fn new( acme: ACMEData, - pg_url: Option<&str>, + pg_url: Option, storage_auth: Option<&str>, ) -> Result { let storage = match pg_url { diff --git a/rama-cli/src/cmd/fp/storage.rs b/rama-cli/src/cmd/fp/storage.rs deleted file mode 100644 index 65645fef..00000000 --- a/rama-cli/src/cmd/fp/storage.rs +++ /dev/null @@ -1,117 +0,0 @@ -use rama::{ - error::OpaqueError, - http::proto::h1::Http1HeaderMap, - net::tls::client::ClientHello, - ua::{Http1Settings, Http2Settings}, -}; - -#[derive(Debug, Clone)] -pub(super) struct Storage; - -impl Storage { - pub(super) async fn new(pg_url: &str) -> Result { - tracing::info!("create new storage with PG URL: {}", pg_url); - Ok(Self) - } -} - -impl Storage { - pub(super) async fn store_h1_settings( - &self, - ua: String, - settings: Http1Settings, - ) -> Result<(), OpaqueError> { - tracing::info!("store h1 settings for UA '{ua}': {settings:?}"); - Ok(()) - } - - pub(super) async fn store_h1_headers_navigate( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h1 navigateheaders for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h1_headers_fetch( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h1 fetch headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h1_headers_xhr( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h1 xhr headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h1_headers_form( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h1 form headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h2_settings( - &self, - ua: String, - settings: Http2Settings, - ) -> Result<(), OpaqueError> { - tracing::info!("store h2 settings for UA '{ua}': {settings:?}"); - Ok(()) - } - - 
pub(super) async fn store_h2_headers_navigate( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h2 navigate headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h2_headers_fetch( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h2 fetch headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h2_headers_xhr( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h2 xhr headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_h2_headers_form( - &self, - ua: String, - headers: Http1HeaderMap, - ) -> Result<(), OpaqueError> { - tracing::info!("store h2 form headers for UA '{ua}': {headers:?}"); - Ok(()) - } - - pub(super) async fn store_tls_client_hello( - &self, - ua: String, - tls_client_hello: ClientHello, - ) -> Result<(), OpaqueError> { - tracing::info!("store tls client hello for UA '{ua}': {tls_client_hello:?}"); - Ok(()) - } -} diff --git a/rama-cli/src/cmd/fp/storage/mod.rs b/rama-cli/src/cmd/fp/storage/mod.rs new file mode 100644 index 00000000..625a2be1 --- /dev/null +++ b/rama-cli/src/cmd/fp/storage/mod.rs @@ -0,0 +1,267 @@ +use rama::{ + error::{ErrorContext, OpaqueError}, + http::proto::h1::Http1HeaderMap, + net::tls::client::ClientHello, + ua::{Http1Settings, Http2Settings}, +}; + +mod postgres; +use postgres::Pool; +use tokio_postgres::types; + +#[derive(Debug, Clone)] +pub(super) struct Storage { + pool: Pool, +} + +impl Storage { + pub(super) async fn new(pg_url: String) -> Result { + tracing::info!("create new storage with PG URL: {}", pg_url); + let pool = postgres::new_pool(pg_url).await?; + Ok(Self { pool }) + } +} + +impl Storage { + pub(super) async fn store_h1_settings( + &self, + ua: String, + settings: Http1Settings, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h1 settings for UA '{ua}': 
{settings:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h1_settings (ua, h1_settings) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_settings = $2", + &[&ua, &types::Json(settings)], + ).await.context("store h1 settings in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 settings for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_navigate( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 navigateheaders for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h1_headers_navigate (ua, h1_headers_navigate) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_navigate = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 navigate headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 navigate headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_fetch( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 fetch headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h1_headers_fetch (ua, h1_headers_fetch) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_fetch = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 fetch headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 fetch headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_xhr( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 xhr headers for UA '{ua}': {headers:?}"); + + let 
client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h1_headers_xhr (ua, h1_headers_xhr) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_xhr = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 xhr headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 xhr headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_form( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h1 form headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h1_headers_form (ua, h1_headers_form) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_form = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 form headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 form headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_settings( + &self, + ua: String, + settings: Http2Settings, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 settings for UA '{ua}': {settings:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h2_settings (ua, h2_settings) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_settings = $2", + &[&ua, &types::Json(settings)], + ).await.context("store h2 settings in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 settings for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_navigate( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 navigate headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; 
+ let n = client.execute( + "INSERT INTO h2_headers_navigate (ua, h2_headers_navigate) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_navigate = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 navigate headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 navigate headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_fetch( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 fetch headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h2_headers_fetch (ua, h2_headers_fetch) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_fetch = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 fetch headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 fetch headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_xhr( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 xhr headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO h2_headers_xhr (ua, h2_headers_xhr) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_xhr = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 xhr headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 xhr headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_form( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::info!("store h2 form headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = 
client.execute( + "INSERT INTO h2_headers_form (ua, h2_headers_form) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_form = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 form headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 form headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_tls_client_hello( + &self, + ua: String, + tls_client_hello: ClientHello, + ) -> Result<(), OpaqueError> { + tracing::info!("store tls client hello for UA '{ua}': {tls_client_hello:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO tls_client_hello (ua, tls_client_hello) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET tls_client_hello = $2", + &[&ua, &types::Json(tls_client_hello)], + ).await.context("store tls client hello in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store tls client hello for UA '{ua}': {n}" + ); + } + + Ok(()) + } +} diff --git a/rama-cli/src/cmd/fp/storage/postgres.rs b/rama-cli/src/cmd/fp/storage/postgres.rs new file mode 100644 index 00000000..bb55cdae --- /dev/null +++ b/rama-cli/src/cmd/fp/storage/postgres.rs @@ -0,0 +1,127 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +use deadpool_postgres::Config; +use rama::{ + error::{ErrorContext, OpaqueError}, + net::{address::Host, stream::Stream}, + tls::{ + boring::client::{TlsStream, tls_connect}, + std::dep::boring::{hash::MessageDigest, nid::Nid, ssl::SslRef}, + }, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, ChannelBinding, MakeTlsConnect, TlsConnect}; + +pub(super) use deadpool_postgres::Pool; + +pub(super) async fn new_pool(url: String) -> Result { + Config { + url: Some(url), + ..Default::default() + } + .create_pool(None, MakeBoringTlsConnector) + .context("create postgres deadpool") +} + +#[derive(Debug, Clone, Default)] 
+#[non_exhaustive] +struct MakeBoringTlsConnector; + +#[derive(Debug, Clone)] +struct BoringTlsConnector { + host: Host, +} + +impl MakeTlsConnect for MakeBoringTlsConnector +where + S: Stream + Unpin, +{ + type Stream = BoringTlsStream; + type TlsConnect = BoringTlsConnector; + type Error = OpaqueError; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + let host: Host = domain.parse().context("parse host from domain")?; + Ok(BoringTlsConnector { host }) + } +} + +impl TlsConnect for BoringTlsConnector +where + S: Stream + Unpin, +{ + type Stream = BoringTlsStream; + type Error = OpaqueError; + #[allow(clippy::type_complexity)] + type Future = Pin, Self::Error>> + Send>>; + + fn connect(self, stream: S) -> Self::Future { + Box::pin(async move { + let tls_stream = tls_connect(self.host, stream, None).await?; + Ok(BoringTlsStream(tls_stream)) + }) + } +} +/// The stream returned by `TlsConnector`. +struct BoringTlsStream(TlsStream); + +impl AsyncRead for BoringTlsStream +where + S: Stream + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for BoringTlsStream +where + S: Stream + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl tls::TlsStream for BoringTlsStream +where + S: Stream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match tls_server_end_point(self.0.ssl_ref()) { + Some(buf) => ChannelBinding::tls_server_end_point(buf), + None => ChannelBinding::none(), + } + } +} + +fn tls_server_end_point(ssl: &SslRef) -> Option> { + let cert = 
ssl.peer_certificate()?; + let algo_nid = cert.signature_algorithm().object().nid(); + let signature_algorithms = algo_nid.signature_algorithms()?; + let md = match signature_algorithms.digest { + Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(), + nid => MessageDigest::from_nid(nid)?, + }; + cert.digest(md).ok().map(|b| b.to_vec()) +} diff --git a/rama-tls/src/boring/client/connector.rs b/rama-tls/src/boring/client/connector.rs index 43d2ed3b..0175d281 100644 --- a/rama-tls/src/boring/client/connector.rs +++ b/rama-tls/src/boring/client/connector.rs @@ -1,6 +1,3 @@ -use super::TlsConnectorData; -use crate::types::TlsTunnel; -use pin_project_lite::pin_project; use private::{ConnectorKindAuto, ConnectorKindSecure, ConnectorKindTunnel}; use rama_core::error::{BoxError, ErrorContext, ErrorExt, OpaqueError}; use rama_core::{Context, Layer, Service}; @@ -12,9 +9,11 @@ use rama_net::tls::client::{ClientConfig, NegotiatedTlsParameters}; use rama_net::transport::TryRefIntoTransportContext; use std::fmt; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio_boring::SslStream; +use super::{AutoTlsStream, TlsConnectorData, TlsStream}; +use crate::types::TlsTunnel; + /// A [`Layer`] which wraps the given service with a [`TlsConnector`]. /// /// See [`TlsConnector`] for more information. 
@@ -254,9 +253,7 @@ where return Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Plain { inner: conn }, - }, + conn: AutoTlsStream::plain(conn), }); } @@ -289,9 +286,7 @@ where Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Secure { inner: stream }, - }, + conn: AutoTlsStream::secure(stream), }) } } @@ -304,7 +299,7 @@ where + Send + 'static, { - type Response = EstablishedClientConnection, State, Request>; + type Response = EstablishedClientConnection, State, Request>; type Error = BoxError; async fn serve( @@ -345,6 +340,7 @@ where }; let (conn, negotiated_params) = self.handshake(connector_data, host, conn).await?; + let conn = TlsStream::new(conn); ctx.insert(negotiated_params); Ok(EstablishedClientConnection { ctx, req, conn }) @@ -382,9 +378,7 @@ where return Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Plain { inner: conn }, - }, + conn: AutoTlsStream::plain(conn), }); } }; @@ -411,13 +405,39 @@ where Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Secure { inner: stream }, - }, + conn: AutoTlsStream::secure(stream), }) } } +pub async fn tls_connect( + server_host: Host, + stream: T, + connector_data: Option<&TlsConnectorData>, +) -> Result, OpaqueError> +where + T: Stream + Unpin, +{ + let client_config_data = match connector_data { + Some(connector_data) => connector_data.try_to_build_config()?, + None => TlsConnectorData::new()?.try_to_build_config()?, + }; + let server_host = client_config_data.server_name.unwrap_or(server_host); + let stream = tokio_boring::connect( + client_config_data.config, + server_host.to_string().as_str(), + stream, + ) + .await + .map_err(|err| match err.as_io_error() { + Some(err) => OpaqueError::from_display(err.to_string()) + .context("boring ssl connector: connect") + .into_boxed(), + None => OpaqueError::from_display("boring ssl 
connector: connect").into_boxed(), + })?; + Ok(TlsStream::new(stream)) +} + impl TlsConnector { async fn handshake( &self, @@ -429,23 +449,8 @@ impl TlsConnector { T: Stream + Unpin, { let connector_data = connector_data.as_ref().or(self.connector_data.as_ref()); - let client_config_data = match connector_data { - Some(connector_data) => connector_data.try_to_build_config()?, - None => TlsConnectorData::new()?.try_to_build_config()?, - }; - let server_host = client_config_data.server_name.unwrap_or(server_host); - let stream = tokio_boring::connect( - client_config_data.config, - server_host.to_string().as_str(), - stream, - ) - .await - .map_err(|err| match err.as_io_error() { - Some(err) => OpaqueError::from_display(err.to_string()) - .context("boring ssl connector: connect") - .into_boxed(), - None => OpaqueError::from_display("boring ssl connector: connect").into_boxed(), - })?; + + let TlsStream { inner: stream } = tls_connect(server_host, stream, connector_data).await?; let params = match stream.ssl().session() { Some(ssl_session) => { @@ -490,94 +495,6 @@ impl TlsConnector { } } -pin_project! { - /// A stream which can be either a secure or a plain stream. - pub struct AutoTlsStream { - #[pin] - inner: AutoTlsStreamData, - } -} - -impl fmt::Debug for AutoTlsStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("AutoTlsStream") - .field("inner", &self.inner) - .finish() - } -} - -pin_project! { - #[project = AutoTlsStreamDataProj] - /// A stream which can be either a secure or a plain stream. - enum AutoTlsStreamData { - /// A secure stream. - Secure{ #[pin] inner: SslStream }, - /// A plain stream. 
- Plain { #[pin] inner: S }, - } -} - -impl fmt::Debug for AutoTlsStreamData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(), - AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(), - } - } -} - -impl AsyncRead for AutoTlsStream -where - S: Stream + Unpin, -{ - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_read(cx, buf), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_read(cx, buf), - } - } -} - -impl AsyncWrite for AutoTlsStream -where - S: Stream + Unpin, -{ - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf), - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx), - } - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx), - } - } -} - mod private { use rama_net::address::Host; diff --git a/rama-tls/src/boring/client/mod.rs b/rama-tls/src/boring/client/mod.rs index 5e97fd8a..a44bfbd7 100644 --- a/rama-tls/src/boring/client/mod.rs +++ b/rama-tls/src/boring/client/mod.rs @@ -1,8 +1,14 @@ -//! 
TLS client support for Rama. +//! TLS (Boring) client support for Rama. + +mod tls_stream_auto; +pub use tls_stream_auto::AutoTlsStream; + +mod tls_stream; +pub use tls_stream::TlsStream; mod connector; #[doc(inline)] -pub use connector::{AutoTlsStream, TlsConnector, TlsConnectorLayer}; +pub use connector::{TlsConnector, TlsConnectorLayer, tls_connect}; mod connector_data; #[doc(inline)] diff --git a/rama-tls/src/boring/client/tls_stream.rs b/rama-tls/src/boring/client/tls_stream.rs new file mode 100644 index 00000000..0bcc1b00 --- /dev/null +++ b/rama-tls/src/boring/client/tls_stream.rs @@ -0,0 +1,85 @@ +use std::fmt; + +use boring::ssl::SslRef; +use pin_project_lite::pin_project; +use rama_net::stream::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_boring::SslStream; + +pin_project! { + /// A stream which can be either a secure or a plain stream. + pub struct TlsStream { + #[pin] + pub(super) inner: SslStream, + } +} + +impl TlsStream { + pub(super) fn new(inner: SslStream) -> Self { + Self { inner } + } + + pub fn ssl_ref(&self) -> &SslRef { + self.inner.ssl() + } +} + +impl fmt::Debug for TlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsStream") + .field("inner", &self.inner) + .finish() + } +} + +impl AsyncRead for TlsStream +where + S: Stream + Unpin, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_read(cx, buf) + } +} + +impl AsyncWrite for TlsStream +where + S: Stream + Unpin, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut 
std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + false + } +} diff --git a/rama-tls/src/boring/client/tls_stream_auto.rs b/rama-tls/src/boring/client/tls_stream_auto.rs new file mode 100644 index 00000000..0888dbe3 --- /dev/null +++ b/rama-tls/src/boring/client/tls_stream_auto.rs @@ -0,0 +1,132 @@ +use std::fmt; + +use boring::ssl::SslRef; +use pin_project_lite::pin_project; +use rama_net::stream::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_boring::SslStream; + +pin_project! { + /// A stream which can be either a secure or a plain stream. + pub struct AutoTlsStream { + #[pin] + inner: AutoTlsStreamData, + } +} + +impl AutoTlsStream { + pub(super) fn secure(inner: SslStream) -> Self { + Self { + inner: AutoTlsStreamData::Secure { inner }, + } + } + + pub(super) fn plain(inner: S) -> Self { + Self { + inner: AutoTlsStreamData::Plain { inner }, + } + } + + pub fn ssl_ref(&self) -> Option<&SslRef> { + match &self.inner { + AutoTlsStreamData::Secure { inner } => Some(inner.ssl()), + AutoTlsStreamData::Plain { .. } => None, + } + } +} + +impl fmt::Debug for AutoTlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AutoTlsStream") + .field("inner", &self.inner) + .finish() + } +} + +pin_project! { + #[project = AutoTlsStreamDataProj] + /// A stream which can be either a secure or a plain stream. + enum AutoTlsStreamData { + /// A secure stream. + Secure{ #[pin] inner: SslStream }, + /// A plain stream. 
+ Plain { #[pin] inner: S }, + } +} + +impl fmt::Debug for AutoTlsStreamData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(), + AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(), + } + } +} + +impl AsyncRead for AutoTlsStream +where + S: Stream + Unpin, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_read(cx, buf), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for AutoTlsStream +where + S: Stream + Unpin, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, 
buf) + } + + fn is_write_vectored(&self) -> bool { + false + } +} From ef26145f8cedaddb3da787cfc96dae4dd94fd111 Mon Sep 17 00:00:00 2001 From: Glen De Cauwsemaecker Date: Sun, 2 Mar 2025 11:10:10 +0100 Subject: [PATCH 39/39] fix fp storage minor issues in code/queries --- rama-cli/src/cmd/fp/storage/mod.rs | 44 ++++++++++++------------- rama-cli/src/cmd/fp/storage/postgres.rs | 1 + 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/rama-cli/src/cmd/fp/storage/mod.rs b/rama-cli/src/cmd/fp/storage/mod.rs index 625a2be1..c0c46be5 100644 --- a/rama-cli/src/cmd/fp/storage/mod.rs +++ b/rama-cli/src/cmd/fp/storage/mod.rs @@ -16,7 +16,7 @@ pub(super) struct Storage { impl Storage { pub(super) async fn new(pg_url: String) -> Result { - tracing::info!("create new storage with PG URL: {}", pg_url); + tracing::debug!("create new storage with PG URL: {}", pg_url); let pool = postgres::new_pool(pg_url).await?; Ok(Self { pool }) } @@ -32,7 +32,7 @@ impl Storage { let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h1_settings (ua, h1_settings) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_settings = $2", + "INSERT INTO \"ua-profiles\" (uastr, h1_settings) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_settings = $2", &[&ua, &types::Json(settings)], ).await.context("store h1 settings in postgres")?; @@ -50,11 +50,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 navigateheaders for UA '{ua}': {headers:?}"); + tracing::debug!("store h1 navigateheaders for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h1_headers_navigate (ua, h1_headers_navigate) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_navigate = $2", + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_navigate) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_navigate = $2", 
&[&ua, &types::Json(headers)], ).await.context("store h1 navigate headers in postgres")?; @@ -72,11 +72,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 fetch headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h1 fetch headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h1_headers_fetch (ua, h1_headers_fetch) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_fetch = $2", + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_fetch) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_fetch = $2", &[&ua, &types::Json(headers)], ).await.context("store h1 fetch headers in postgres")?; @@ -94,11 +94,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 xhr headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h1 xhr headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h1_headers_xhr (ua, h1_headers_xhr) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_xhr = $2", + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_xhr) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_xhr = $2", &[&ua, &types::Json(headers)], ).await.context("store h1 xhr headers in postgres")?; @@ -116,11 +116,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h1 form headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h1 form headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h1_headers_form (ua, h1_headers_form) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h1_headers_form = $2", + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_form) VALUES ($1, $2) ON CONFLICT (uastr) DO 
UPDATE SET h1_headers_form = $2", &[&ua, &types::Json(headers)], ).await.context("store h1 form headers in postgres")?; @@ -138,11 +138,11 @@ impl Storage { ua: String, settings: Http2Settings, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 settings for UA '{ua}': {settings:?}"); + tracing::debug!("store h2 settings for UA '{ua}': {settings:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h2_settings (ua, h2_settings) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_settings = $2", + "INSERT INTO \"ua-profiles\" (uastr, h2_settings) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_settings = $2", &[&ua, &types::Json(settings)], ).await.context("store h2 settings in postgres")?; @@ -160,11 +160,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 navigate headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h2 navigate headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h2_headers_navigate (ua, h2_headers_navigate) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_navigate = $2", + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_navigate) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_navigate = $2", &[&ua, &types::Json(headers)], ).await.context("store h2 navigate headers in postgres")?; @@ -182,11 +182,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 fetch headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h2 fetch headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h2_headers_fetch (ua, h2_headers_fetch) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_fetch = $2", + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_fetch) 
VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_fetch = $2", &[&ua, &types::Json(headers)], ).await.context("store h2 fetch headers in postgres")?; @@ -204,11 +204,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 xhr headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h2 xhr headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h2_headers_xhr (ua, h2_headers_xhr) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_xhr = $2", + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_xhr) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_xhr = $2", &[&ua, &types::Json(headers)], ).await.context("store h2 xhr headers in postgres")?; @@ -226,11 +226,11 @@ impl Storage { ua: String, headers: Http1HeaderMap, ) -> Result<(), OpaqueError> { - tracing::info!("store h2 form headers for UA '{ua}': {headers:?}"); + tracing::debug!("store h2 form headers for UA '{ua}': {headers:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO h2_headers_form (ua, h2_headers_form) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET h2_headers_form = $2", + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_form) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_form = $2", &[&ua, &types::Json(headers)], ).await.context("store h2 form headers in postgres")?; @@ -248,11 +248,11 @@ impl Storage { ua: String, tls_client_hello: ClientHello, ) -> Result<(), OpaqueError> { - tracing::info!("store tls client hello for UA '{ua}': {tls_client_hello:?}"); + tracing::debug!("store tls client hello for UA '{ua}': {tls_client_hello:?}"); let client = self.pool.get().await.context("get postgres client")?; let n = client.execute( - "INSERT INTO tls_client_hello (ua, tls_client_hello) VALUES ($1, $2) ON CONFLICT (ua) DO UPDATE SET tls_client_hello = $2", + 
"INSERT INTO \"ua-profiles\" (uastr, tls_client_hello) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET tls_client_hello = $2", &[&ua, &types::Json(tls_client_hello)], ).await.context("store tls client hello in postgres")?; diff --git a/rama-cli/src/cmd/fp/storage/postgres.rs b/rama-cli/src/cmd/fp/storage/postgres.rs index bb55cdae..90fe12c1 100644 --- a/rama-cli/src/cmd/fp/storage/postgres.rs +++ b/rama-cli/src/cmd/fp/storage/postgres.rs @@ -21,6 +21,7 @@ pub(super) use deadpool_postgres::Pool; pub(super) async fn new_pool(url: String) -> Result { Config { url: Some(url), + dbname: Some("fp".to_owned()), ..Default::default() } .create_pool(None, MakeBoringTlsConnector)