From d3ff3a568cf87bc42453531733c3d8e03267438e Mon Sep 17 00:00:00 2001 From: Li Jie Date: Mon, 30 Dec 2024 18:11:31 +0800 Subject: [PATCH] port rack/query_parser --- Cargo.lock | 9 + Cargo.toml | 6 +- src/error.rs | 5 +- src/lib.rs | 2 + src/params.rs | 342 +++++++++--------------------- src/query_parser.rs | 499 ++++++++++++++++++++++++++++++++++++++++++++ src/serde.rs | 84 ++++---- src/tests.rs | 303 +++++++++------------------ src/value.rs | 300 +++++++++++++++++--------- 9 files changed, 958 insertions(+), 592 deletions(-) create mode 100644 src/query_parser.rs diff --git a/Cargo.lock b/Cargo.lock index 4c30607..1f1fc1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -203,10 +203,13 @@ dependencies = [ "axum-macros", "axum-test", "env_logger", + "form_urlencoded", "futures-util", "log", + "maplit", "mime", "multer", + "pretty_assertions", "serde", "serde_json", "serde_urlencoded", @@ -768,6 +771,12 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matchit" version = "0.7.3" diff --git a/Cargo.toml b/Cargo.toml index f4b0104..0799814 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,12 +11,14 @@ categories = ["web-programming"] [dependencies] actson = "1.1.0" async-trait = "0.1.83" -axum = { version = "0.7.9", features = ["multipart", "macros"] } +axum = { version = "0.7", features = ["multipart", "macros"] } axum-macros = "0.4.2" +form_urlencoded = "1.2.1" log = "0.4.20" mime = "0.3.17" multer = "3.0.0" serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.134" serde_urlencoded = "0.7.1" tempfile = "3.8.1" tokio = { version = "1.34.0", features = ["full"] } @@ -26,6 +28,8 @@ url = "2.5.4" axum-test = "16.4.1" env_logger = "0.11.6" futures-util = "0.3.29" +maplit = "1.0.2" +pretty_assertions = "1.4.0" serde_json = "1.0.134" [[example]] diff --git a/src/error.rs b/src/error.rs index 2ca749c..4c4db80 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,10 +8,7 @@ pub enum Error { DecodeError(String), ReadError(String), IOError(String), - InvalidRequest, - InvalidParams, - InvalidPath, - InvalidFile, + MergeError(String), } impl IntoResponse for Error { diff --git a/src/lib.rs b/src/lib.rs index 0ff1234..49f6015 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod error; mod params; +pub mod query_parser; mod serde; #[cfg(test)] mod tests; @@ -9,6 +10,7 @@ mod value; pub use error::*; pub use params::*; +pub use query_parser::*; pub use serde::*; pub use traits::*; pub use upload_file::*; diff --git a/src/params.rs b/src/params.rs index e95dff1..44f0148 100644 --- a/src/params.rs +++ b/src/params.rs @@ -1,4 +1,4 @@ -use crate::{merge_json, Error, ParamsValue, UploadFile}; +use crate::{parse_json, Error, QueryParser, UploadFile, Value}; use ::serde::de::DeserializeOwned; use actson::feeder::SliceJsonFeeder; use axum::{ @@ -10,7 +10,6 @@ use axum::{ use log::debug; use std::collections::HashMap; use tempfile::NamedTempFile; -use url::form_urlencoded; #[derive(Debug, Default)] pub struct Params(pub T, pub Vec); @@ -28,47 +27,38 @@ where req.method() == http::Method::GET || req.method() == http::Method::HEAD; let (mut parts, body) = req.into_parts(); - // Start with empty vec to preserve multiple values for the same key - let mut merged: HashMap> = HashMap::new(); + let parser = QueryParser::new(None); + let mut merged_params = HashMap::new(); // Extract path parameters if let Ok(Path(params)) = Path::>::from_request_parts(&mut parts, state).await { debug!("params: {:?}", params); + for (key, value) in params { - // Remove query string from path parameter if present - let value = if let Some(pos) = value.find('?') { - value[..pos].to_string() - } else { - value - }; - merged - .entry(key) - .or_default() - .push(ParamsValue::Convertible(value)); + parser + .parse_nested_value(&mut merged_params, key.as_str(), Value::xstr(value)) + .map_err(|e| { + Error::DecodeError(format!("Failed to parse path parameters: {}", e)) + })?; } } - debug!("merged path params: {:?}", merged); + debug!("merged path params: {:?}", merged_params); debug!("parts.uri: {:?}", parts.uri); debug!("parts.uri.query(): {:?}", parts.uri.query()); // Extract query parameters from URI if let Some(query) = parts.uri.query() { - let params: Vec<_> = form_urlencoded::parse(query.as_bytes()) - .into_owned() - .collect(); - debug!("query params: {:?}", params); - for (key, value) in params { - merged - .entry(key) - .or_default() - .push(ParamsValue::Convertible(value)); - } + parser + .parse_nested_query_into(&mut merged_params, query) + .map_err(|e| { + Error::DecodeError(format!("Failed to parse query parameters: {}", e)) + })?; } - debug!("merged query params: {:?}", merged); + debug!("merged query params: {:?}", merged_params); let mut temp_files = Vec::new(); debug!( @@ -85,8 +75,13 @@ where Error::DecodeError(format!("Failed to read JSON request body: {}", e)) })?; let feeder = SliceJsonFeeder::new(&bytes); - merge_json(feeder, &mut merged)?; - debug!("merged json: {:#?}", merged); + let value = parse_json(feeder)?; + debug!("parsed json: {:#?}", value); + merged_params = value.merge_into(merged_params).map_err(|e| { + debug!("Failed to merge JSON data: {e:?}"); + Error::DecodeError(format!("Failed to merge JSON data: {e:?}")) + })?; + debug!("merged json: {:#?}", merged_params); } ct if ct.starts_with("application/x-www-form-urlencoded") => { if !is_get_or_head { @@ -95,25 +90,17 @@ where "Failed to read form-urlencoded request body: {e}" )) })?; - if let Ok(map) = - serde_urlencoded::from_bytes::>(&bytes) - .map_err(|err| -> Error { - debug!( - "Failed to deserialize form-urlencoded data: {}", - err - ); - Error::DecodeError(format!( - "Failed to deserialize form: {err}", - )) - }) - { - for (k, v) in map { - merged - .entry(k) - .or_default() - .push(ParamsValue::Convertible(v)); - } - } + parser + .parse_nested_query_into( + &mut merged_params, + String::from_utf8_lossy(&bytes).as_ref(), + ) + .map_err(|e| { + Error::DecodeError(format!( + "Failed to parse form-urlencoded body: {}", + e + )) + })? } } ct if ct.starts_with("multipart/form-data") => { @@ -137,7 +124,7 @@ where let bytes = field.bytes().await.map_err(|e| { debug!("Failed to read JSON field bytes: {}", e); Error::ReadError(format!( - "Failed to read JSON field bytes: {e}" + "Failed to read JSON field bytes: {e}", )) })?; debug!( @@ -145,32 +132,33 @@ where String::from_utf8(bytes.to_vec()).unwrap() ); let feeder = SliceJsonFeeder::new(&bytes); - let mut temp_map = HashMap::new(); - merge_json(feeder, &mut temp_map)?; - debug!("Parsed JSON field: {:#?}", temp_map); + let value = parse_json(feeder)?; + debug!("Parsed JSON field: {:#?}", value); let name = name.unwrap_or_default(); if name.is_empty() { - // If no field name, clear all existing data and merge only the JSON data - for (key, values) in temp_map { - merged.insert(key, values); - } - debug!("Merged JSON field: {:#?}", merged); - continue; - } - - // If we have a single value in the map with key "", use it as the value - if let Some(values) = temp_map.get("") { - if values.len() == 1 { - merged.insert(name, values.clone()); - continue; - } + merged_params = + value.merge_into(merged_params).map_err(|e| { + debug!("Failed to merge JSON field: {e:?}"); + Error::DecodeError(format!( + "Failed to merge JSON field: {e:?}", + )) + })?; + } else { + parser + .parse_nested_value( + &mut merged_params, + name.as_str(), + value, + ) + .map_err(|e| { + Error::DecodeError(format!( + "Failed to parse JSON field: {}", + e + )) + })?; } - // Otherwise, process the map as nested parameters - let value = process_nested_params(temp_map); - merged.insert(name, vec![value]); - - debug!("Merged JSON field: {:#?}", merged); + debug!("Merged JSON field: {:#?}", merged_params); continue; } if let Some(name) = field.name() { @@ -222,22 +210,27 @@ where debug!("Total bytes written to file: {}", total_bytes); - merged - .entry(name) - .or_default() - .push(ParamsValue::UploadFile(UploadFile { - name: field.file_name().unwrap().to_string(), - content_type: field - .content_type() - .map(|ct| ct.to_string()) - .unwrap_or_else(|| { - "application/octet-stream".to_string() - }), - temp_file_path: temp_file - .path() - .to_string_lossy() - .to_string(), - })); + let file = Value::UploadFile(UploadFile { + name: field.file_name().unwrap().to_string(), + content_type: field + .content_type() + .map(|ct| ct.to_string()) + .unwrap_or_else(|| { + "application/octet-stream".to_string() + }), + temp_file_path: temp_file + .path() + .to_string_lossy() + .to_string(), + }); + parser + .parse_nested_value(&mut merged_params, name.as_str(), file) + .map_err(|e| { + Error::DecodeError(format!( + "Failed to parse file upload field: {}", + e + )) + })?; // Store the temp file temp_files.push(temp_file); @@ -247,10 +240,18 @@ where debug!("Failed to read text field: {}", e); Error::ReadError(format!("Failed to read text field: {e}",)) })?; - merged - .entry(name) - .or_default() - .push(ParamsValue::Convertible(value)); + parser + .parse_nested_value( + &mut merged_params, + name.as_str(), + Value::xstr(value), + ) + .map_err(|e| { + Error::DecodeError(format!( + "Failed to parse text field: {}", + e + )) + })?; } } } @@ -261,165 +262,10 @@ where } } } - let merged = process_nested_params(merged); - debug!("merged: {:?}", merged); - T::deserialize(merged) + + debug!("merged: {:?}", merged_params); + T::deserialize(Value::Object(merged_params)) .map_err(|e| Error::DecodeError(format!("Failed to deserialize parameters: {e}"))) .map(|payload| Params(payload, temp_files)) } } - -pub fn process_nested_params(grouped: HashMap>) -> ParamsValue { - debug!("Starting process_nested_params with input: {:?}", grouped); - let mut result = HashMap::new(); - - // Process each group - for (key, values) in grouped { - debug!("Processing key: {} with values: {:?}", key, values); - let parts = parse_key_parts(&key); - debug!("Parsed parts: {:?}", parts); - if parts.is_empty() { - continue; - } - - // For single-part keys, directly add the value - if parts.len() == 1 { - let value = if values.len() == 1 { - values.into_iter().next().unwrap() - } else { - ParamsValue::Array(values) - }; - debug!( - "Adding single-part key: {} with value: {:?}", - parts[0], value - ); - result.insert(parts[0].clone(), value); - continue; - } - - // Get the value from insert_nested_values and store it in the result - let value = insert_nested_values(&mut result, &parts, values); - if parts.len() == 1 { - debug!("Adding nested key: {} with value: {:?}", parts[0], value); - result.insert(parts[0].clone(), value); - } - } - - debug!("Final result: {:?}", result); - ParamsValue::Object(result) -} - -fn insert_nested_values( - map: &mut HashMap, - parts: &[String], - values: Vec, -) -> ParamsValue { - if parts.is_empty() { - return values - .into_iter() - .next() - .unwrap_or_else(|| ParamsValue::Object(HashMap::new())); - } - - let key = &parts[0]; - if parts.len() == 1 { - let value = if values.len() == 1 { - values.into_iter().next().unwrap() - } else { - ParamsValue::Array(values) - }; - return value; - } - - // Check if next part indicates an array - let is_array = parts - .get(1) - .map(|p| p.is_empty() || p.parse::().is_ok()) - .unwrap_or(false); - - let entry = map.entry(key.clone()).or_insert_with(|| { - if is_array { - ParamsValue::Array(Vec::new()) - } else { - ParamsValue::Object(HashMap::new()) - } - }); - - match entry { - ParamsValue::Object(nested_map) => { - let value = insert_nested_values(nested_map, &parts[1..], values); - if parts.len() == 2 { - nested_map.insert(parts[1].clone(), value.clone()); - } - ParamsValue::Object(nested_map.clone()) - } - ParamsValue::Array(vec) => { - if parts.get(1).map(|p| p.is_empty()).unwrap_or(false) { - vec.extend(values); - } else if let Some(Ok(index)) = parts.get(1).map(|p| p.parse::()) { - while vec.len() <= index { - vec.push(ParamsValue::Object(HashMap::new())); - } - - if parts.len() == 2 { - if let Some(value) = values.into_iter().next() { - vec[index] = value; - } - } else if let ParamsValue::Object(nested_map) = &mut vec[index] { - let value = insert_nested_values(nested_map, &parts[2..], values); - if parts.len() == 3 { - nested_map.insert(parts[2].clone(), value); - } - } - } - ParamsValue::Array(vec.clone()) - } - _ => values - .into_iter() - .next() - .unwrap_or_else(|| ParamsValue::Object(HashMap::new())), - } -} - -fn parse_key_parts(key: &str) -> Vec { - debug!("Parsing key parts for: {}", key); - let mut parts = Vec::new(); - let mut current = String::new(); - let mut in_brackets = false; - - for c in key.chars() { - match c { - '[' => { - if !current.is_empty() { - debug!("Adding part before bracket: {}", current); - parts.push(current.clone()); - current.clear(); - } - in_brackets = true; - } - ']' => { - if in_brackets { - if current.is_empty() { - debug!("Found empty brackets"); - parts.push(String::new()); - } else { - debug!("Adding part from bracket: {}", current); - parts.push(current.clone()); - } - current.clear(); - } - in_brackets = false; - } - _ => { - current.push(c); - } - } - } - - if !current.is_empty() { - debug!("Adding remaining part: {}", current); - parts.push(current); - } - - parts -} diff --git a/src/query_parser.rs b/src/query_parser.rs new file mode 100644 index 0000000..29ab0fe --- /dev/null +++ b/src/query_parser.rs @@ -0,0 +1,499 @@ +// Port from: https://github.com/rack/rack/blob/main/lib/rack/query_parser.rb + +use form_urlencoded; +use std::collections::HashMap; +use std::error::Error; +use std::fmt; + +use crate::Value; + +const DEFAULT_PARAM_DEPTH_LIMIT: usize = 100; + +#[derive(Debug)] +pub enum QueryParserError { + ParameterTypeError(String), + InvalidParameterError(String), + ParamsTooDeepError(String), +} + +impl fmt::Display for QueryParserError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + QueryParserError::ParameterTypeError(msg) => write!(f, "Parameter type error: {}", msg), + QueryParserError::InvalidParameterError(msg) => write!(f, "Invalid parameter: {}", msg), + QueryParserError::ParamsTooDeepError(msg) => write!(f, "Parameters too deep: {}", msg), + } + } +} + +impl Error for QueryParserError {} + +pub struct QueryParser { + param_depth_limit: usize, +} + +impl QueryParser { + pub fn new(param_depth_limit: Option) -> Self { + Self { + param_depth_limit: param_depth_limit.unwrap_or(DEFAULT_PARAM_DEPTH_LIMIT), + } + } + + pub fn parse_nested_query<'a>( + &self, + qs: impl Into>, + ) -> Result, QueryParserError> { + let mut params = HashMap::new(); + self.parse_nested_query_into(&mut params, qs)?; + Ok(params) + } + + pub fn parse_nested_query_into<'a>( + &self, + params: &mut HashMap, + qs: impl Into>, + ) -> Result<(), QueryParserError> { + let qs = qs.into().unwrap_or(""); + + if qs.is_empty() { + return Ok(()); + } + + for pair in qs.split('&') { + if pair.is_empty() { + continue; + } + + let (key, value) = match pair.split_once('=') { + Some((k, v)) => { + let k = form_urlencoded::parse(k.as_bytes()) + .next() + .map(|(k, _)| k.into_owned()) + .unwrap_or_default(); + let v = form_urlencoded::parse(v.as_bytes()) + .next() + .map(|(v, _)| v.into_owned()) + .unwrap_or_default(); + (k, Some(v)) + } + None => { + let k = form_urlencoded::parse(pair.as_bytes()) + .next() + .map(|(k, _)| k.into_owned()) + .unwrap_or_default(); + (k, None) + } + }; + + let value = Value::xstr_opt(value); + self._normalize_params(params, &key, value, 0)?; + } + + Ok(()) + } + + pub fn parse_nested_value<'a>( + &self, + params: &mut HashMap, + key: impl Into>, + value: Value, + ) -> Result<(), QueryParserError> { + let key = key.into().unwrap_or(""); + + if key.is_empty() { + return Ok(()); + } + + self._normalize_params(params, key, value, 0)?; + Ok(()) + } + + fn _normalize_params( + &self, + params: &mut HashMap, + name: &str, + v: Value, + depth: usize, + ) -> Result { + if depth >= self.param_depth_limit { + return Err(QueryParserError::ParamsTooDeepError( + "Parameters nested too deep".to_string(), + )); + } + + let (k, after) = if name.is_empty() { + ("", "") + } else if depth == 0 { + if let Some(start) = name[1..].find('[') { + let start = start + 1; + (&name[..start], &name[start..]) + } else { + (name, "") + } + } else if let Some(stripped) = name.strip_prefix("[]") { + ("[]", stripped) + } else if let Some(stripped) = name.strip_prefix("[") { + if let Some(start) = stripped.find(']') { + (&stripped[..start], &stripped[start + 1..]) + } else { + (name, "") + } + } else { + (name, "") + }; + + if k.is_empty() { + return Ok(Value::Null); + } + + if after.is_empty() { + if k == "[]" && depth != 0 { + return Ok(Value::Array(vec![v])); + } + params.insert(k.to_string(), v); + } else if after == "[" { + params.insert(name.to_string(), v); + } else if after == "[]" { + let entry = params + .entry(k.to_string()) + .or_insert_with(|| Value::Array(Vec::new())); + + if let Value::Array(vec) = entry { + vec.push(v); + } else { + return Err(QueryParserError::ParameterTypeError(format!( + "expected Array (got {}) for param `{}`", + entry.type_name(), + k + ))); + } + } else if let Some(after) = after.strip_prefix("[]") { + // Recognize x[][y] (hash inside array) parameters + let child_key = if !after.starts_with('[') + || !after.ends_with(']') + || after[1..after.len() - 1].contains('[') + || after[1..after.len() - 1].contains(']') + || after[1..after.len() - 1].is_empty() + { + after + } else { + &after[1..after.len() - 1] + }; + + let entry = params + .entry(k.to_string()) + .or_insert_with(|| Value::Array(Vec::new())); + if let Value::Array(vec) = entry { + let mut new_params = HashMap::new(); + if let Some(Value::Object(hash)) = vec.last_mut() { + if !params_hash_has_key(hash, child_key) { + let _ = self._normalize_params(&mut *hash, child_key, v.clone(), depth + 1); + } else { + let normalized = self._normalize_params( + &mut new_params, + child_key, + v.clone(), + depth + 1, + )?; + vec.push(normalized); + } + } else { + let normalized = + self._normalize_params(&mut new_params, child_key, v.clone(), depth + 1)?; + vec.push(normalized); + } + } else { + return Err(QueryParserError::ParameterTypeError(format!( + "expected Array (got {}) for param `{}`", + entry.type_name(), + k + ))); + } + } else { + let entry = params + .entry(k.to_string()) + .or_insert_with(|| Value::Object(HashMap::new())); + + if let Value::Object(hash) = entry { + self._normalize_params(hash, after, v, depth + 1)?; + } else { + return Err(QueryParserError::ParameterTypeError(format!( + "expected Object (got {}) for param `{}`", + entry.type_name(), + k + ))); + } + } + + Ok(Value::Object(params.to_owned())) + } +} + +fn params_hash_has_key(hash: &HashMap, key: &str) -> bool { + if key.contains("[]") { + return false; + } + let parts: Vec<&str> = key + .split(['[', ']']) + .filter(|&part| !part.is_empty()) + .collect(); + + let mut current = hash; + for part in parts { + if let Some(next) = current.get(part) { + if let Value::Object(map) = next { + current = map; + } else { + return true; + } + } else { + return false; + } + } + true +} + +#[cfg(test)] +mod tests { + // Port from: https://github.com/rack/rack/blob/main/test/spec_utils.rb + + use crate::query_parser::{QueryParser, Value, DEFAULT_PARAM_DEPTH_LIMIT}; + use maplit::hashmap; + use pretty_assertions::assert_eq; + use std::collections::HashMap; + + trait ParseTest { + fn should_be(&self, expected: &str); + } + + impl<'a> ParseTest for &'a str { + fn should_be(&self, expected: &str) { + let parser = QueryParser::new(None); + assert_eq!( + Value::Object(parser.parse_nested_query(*self).unwrap()), + convert(expected) + ); + } + } + + fn convert(json: &str) -> Value { + let json: serde_json::Value = serde_json::from_str(json).unwrap(); + Value::from(&json) + } + + fn setup() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[test] + fn parse_nil_as_an_empty_query_string() { + let parser = QueryParser::new(None); + assert_eq!(parser.parse_nested_query(None).unwrap(), HashMap::new()); + } + + #[test] + fn raise_an_exception_if_the_params_are_too_deep() { + let parser = QueryParser::new(Some(DEFAULT_PARAM_DEPTH_LIMIT)); + let deep_string = "[a]".repeat(DEFAULT_PARAM_DEPTH_LIMIT); + let query_string = format!("foo{}=bar", deep_string); + let result = parser.parse_nested_query(&*query_string); + assert!(result.is_err()); + } + + #[test] + fn test_parse_nested_query_strings_correctly() { + setup(); + + "foo".should_be(r#"{"foo": null}"#); + "foo=".should_be(r#"{"foo": ""}"#); + "foo=bar".should_be(r#"{"foo": "bar"}"#); + "foo=\"bar\"".should_be(r#"{"foo": "\"bar\""}"#); + + "foo=bar&foo=quux".should_be(r#"{"foo": "quux"}"#); + "foo&foo=".should_be(r#"{"foo": ""}"#); + "foo=1&bar=2".should_be(r#"{"foo": "1", "bar": "2"}"#); + "&foo=1&&bar=2".should_be(r#"{"foo": "1", "bar": "2"}"#); + "foo&bar=".should_be(r#"{"foo": null, "bar": ""}"#); + "foo=bar&baz=".should_be(r#"{"foo": "bar", "baz": ""}"#); + "&foo=1&&bar=2".should_be(r#"{"foo": "1", "bar": "2"}"#); + "foo&bar=".should_be(r#"{"foo": null, "bar": ""}"#); + "foo=bar&baz=".should_be(r#"{"foo": "bar", "baz": ""}"#); + "my+weird+field=q1%212%22%27w%245%267%2Fz8%29%3F" + .should_be(r#"{"my weird field": "q1!2\"'w$5&7/z8)?"}"#); + + "a=b&pid%3D1234=1023".should_be(r#"{"pid=1234": "1023", "a": "b"}"#); + + "foo[]".should_be(r#"{"foo": [null]}"#); + "foo[]=".should_be(r#"{"foo": [""]}"#); + "foo[]=bar".should_be(r#"{"foo": ["bar"]}"#); + "foo[]=bar&foo".should_be(r#"{"foo": null}"#); + "foo[]=bar&foo[".should_be(r#"{"foo": ["bar"], "foo[": null}"#); + "foo[]=bar&foo[=baz".should_be(r#"{"foo": ["bar"], "foo[": "baz"}"#); + "foo[]=bar&foo[]".should_be(r#"{"foo": ["bar", null]}"#); + "foo[]=bar&foo[]=".should_be(r#"{"foo": ["bar", ""]}"#); + + "foo[]=1&foo[]=2".should_be(r#"{"foo": ["1", "2"]}"#); + "foo=bar&baz[]=1&baz[]=2&baz[]=3".should_be(r#"{"foo": "bar", "baz": ["1", "2", "3"]}"#); + "foo[]=bar&baz[]=1&baz[]=2&baz[]=3" + .should_be(r#"{"foo": ["bar"], "baz": ["1", "2", "3"]}"#); + + "x[y][z]".should_be(r#"{"x": { "y": { "z": null } }}"#); + "x[y][z]=1".should_be(r#"{"x": { "y": { "z": "1"} }}"#); + "x[y][z][]=1".should_be(r#"{"x": { "y": { "z": ["1"] } }}"#); + "x[y][z]=1&x[y][z]=2".should_be(r#"{"x": { "y": { "z": "2"} }}"#); + "x[y][z][]=1&x[y][z][]=2".should_be(r#"{"x": { "y": { "z": ["1", "2"] } }}"#); + + "x[y][][z]=1".should_be(r#"{"x": { "y": [{ "z": "1" }] }}"#); + "x[y][][z][]=1".should_be(r#"{"x": { "y": [{ "z": ["1"] }] }}"#); + "x[y][][z]=1&x[y][][w]=2".should_be(r#"{"x": { "y": [{ "z": "1", "w": "2" }] }}"#); + + "x[y][][v][w]=1".should_be(r#"{"x": { "y": [{ "v": { "w": "1" } }] }}"#); + "x[y][][z]=1&x[y][][v][w]=2" + .should_be(r#"{"x": { "y": [{ "z": "1", "v": { "w": "2" } }] }}"#); + + "x[y][][z]=1&x[y][][z]=2".should_be(r#"{"x": { "y": [{ "z": "1" }, { "z": "2" }] }}"#); + "x[y][][z]=1&x[y][][w]=a&x[y][][z]=2&x[y][][w]=3" + .should_be(r#"{"x": { "y": [{ "z": "1", "w": "a" }, { "z": "2", "w": "3" }] }}"#); + + "x[][y]=1&x[][z][w]=a&x[][y]=2&x[][z][w]=b".should_be( + r#"{"x": [{ "y": "1", "z": { "w": "a" } }, { "y": "2", "z": { "w": "b" } }]}"#, + ); + "x[][z][w]=a&x[][y]=1&x[][z][w]=b&x[][y]=2".should_be( + r#"{"x": [{ "y": "1", "z": { "w": "a" } }, { "y": "2", "z": { "w": "b" } }]}"#, + ); + + "data[books][][data][page]=1&data[books][][data][page]=2".should_be( + r#"{"data": { "books": [{ "data": { "page": "1" } }, { "data": { "page": "2" } }] }}"#, + ) + } + + #[test] + fn test_parse_empty() { + let parser = QueryParser::new(None); + assert_eq!(parser.parse_nested_query("").unwrap(), HashMap::new()); + assert_eq!(parser.parse_nested_query(None).unwrap(), HashMap::new()); + } + + #[test] + fn test_parse_empty_key_value() { + let parser = QueryParser::new(None); + + // Test empty key with value + assert_eq!(parser.parse_nested_query("=value").unwrap(), hashmap! {}); + + // Test key with empty value + assert_eq!( + parser.parse_nested_query("key=").unwrap(), + hashmap! { + "key".to_string() => Value::xstr("") + } + ); + + // Test empty key-value pair + assert_eq!(parser.parse_nested_query("=").unwrap(), hashmap! {}); + + // Test key without value + assert_eq!( + parser.parse_nested_query("&key&").unwrap(), + hashmap! { + "key".to_string() => Value::Null + } + ); + } + + #[test] + fn test_parse_duplicate_keys() { + let parser = QueryParser::new(None); + + // Test duplicate keys (last value wins) + assert_eq!( + parser.parse_nested_query("foo=bar&foo=quux").unwrap(), + hashmap! { + "foo".to_string() => Value::xstr("quux") + } + ); + + // Test key without value followed by key with value + assert_eq!( + parser.parse_nested_query("foo&foo=").unwrap(), + hashmap! { + "foo".to_string() => Value::xstr("") + } + ); + + // Test key with value followed by key without value + assert_eq!( + parser.parse_nested_query("foo=bar&foo").unwrap(), + hashmap! { + "foo".to_string() => Value::Null + } + ); + } + + #[test] + fn test_parse_array_edge_cases() { + setup(); + let parser = QueryParser::new(None); + + // Test array followed by plain key + assert_eq!( + parser.parse_nested_query("foo[]=bar&foo").unwrap(), + hashmap! { + "foo".to_string() => Value::Null + } + ); + + // Test array followed by incomplete array syntax + assert_eq!( + parser.parse_nested_query("foo[]=bar&foo[").unwrap(), + hashmap! { + "foo".to_string() => Value::Array(vec![Value::xstr("bar")]), + "foo[".to_string() => Value::Null + } + ); + + // Test array followed by incomplete array with value + assert_eq!( + parser.parse_nested_query("foo[]=bar&foo[=baz").unwrap(), + hashmap! { + "foo".to_string() => Value::Array(vec![Value::xstr("bar")]), + "foo[".to_string() => Value::xstr("baz") + } + ); + } + + #[test] + // can parse a query string with a key that has invalid UTF-8 encoded bytes + fn test_parse_invalid_utf8() { + let parser = QueryParser::new(None); + let result = parser.parse_nested_query("foo%81E=1").unwrap_or_default(); + assert_eq!(result.len(), 1); + let key = result.keys().next().unwrap().as_bytes(); + assert_eq!(key, b"foo\xEF\xBF\xBDE"); + } + + #[test] + fn only_moves_to_a_new_array_when_the_full_key_has_been_seen() { + "x[][y][][z]=1&x[][y][][w]=2".should_be(r#"{"x": [{ "y": [{ "z": "1", "w": "2" }] }]}"#); + "x[][id]=1&x[][y][a]=5&x[][y][b]=7&x[][z][id]=3&x[][z][w]=0&x[][id]=2&x[][y][a]=6&x[][y][b]=8&x[][z][id]=4&x[][z][w]=0" + .should_be( + r#" + { + "x": [ + { "id": "1", "y": { "a": "5", "b": "7" }, "z": { "id": "3", "w": "0" } }, + { "id": "2", "y": { "a": "6", "b": "8" }, "z": { "id": "4", "w": "0" } } + ] + }"#, + ); + } + + #[test] + fn handles_unexpected_use_of_brackets_in_parameter_keys_as_normal_characters() { + "[]=1&[a]=2&b[=3&c]=4".should_be(r#"{"[]": "1", "[a]": "2", "b[": "3", "c]": "4"}"#); + "d[[]=5&e][]=6&f[[]]=7" + .should_be(r#"{"d": {"[": "5"}, "e]": ["6"], "f": { "[": { "]": "7" } }}"#); + "g[h]i=8&j[k]l[m]=9" + .should_be(r#"{"g": { "h": { "i": "8" } }, "j": { "k": { "l[m]": "9" } }}"#); + "l[[[[[[[[]]]]]]]=10".should_be(r#"{"l": {"[[[[[[[": {"]]]]]]": "10"}}}"#); + } +} diff --git a/src/serde.rs b/src/serde.rs index ca79fdf..f4fd294 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -1,6 +1,6 @@ use crate::{Number, N}; -use super::ParamsValue; +use super::Value; use log::debug; use serde::{ de::{self, MapAccess, SeqAccess, Visitor}, @@ -11,41 +11,41 @@ use std::collections::HashMap; struct ParamsValueVisitor; impl<'de> Visitor<'de> for ParamsValueVisitor { - type Value = ParamsValue; + type Value = Value; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("any valid JSON value or upload file") } fn visit_bool(self, v: bool) -> Result { - Ok(ParamsValue::Bool(v)) + Ok(Value::Bool(v)) } fn visit_i64(self, v: i64) -> Result { - Ok(ParamsValue::Number(Number::from(v))) + Ok(Value::Number(Number::from(v))) } fn visit_u64(self, v: u64) -> Result { - Ok(ParamsValue::Number(Number::from(v))) + Ok(Value::Number(Number::from(v))) } fn visit_f64(self, v: f64) -> Result { - Ok(ParamsValue::Number(Number::from(v))) + Ok(Value::Number(Number::from(v))) } fn visit_str(self, v: &str) -> Result where E: de::Error, { - Ok(ParamsValue::Convertible(v.to_owned())) + Ok(Value::XStr(v.to_owned())) } fn visit_string(self, v: String) -> Result { - Ok(ParamsValue::Convertible(v)) + Ok(Value::XStr(v)) } fn visit_none(self) -> Result { - Ok(ParamsValue::Null) + Ok(Value::Null) } fn visit_some(self, deserializer: D) -> Result @@ -56,7 +56,7 @@ impl<'de> Visitor<'de> for ParamsValueVisitor { } fn visit_unit(self) -> Result { - Ok(ParamsValue::Null) + Ok(Value::Null) } fn visit_seq(self, mut seq: A) -> Result @@ -67,7 +67,7 @@ impl<'de> Visitor<'de> for ParamsValueVisitor { while let Some(elem) = seq.next_element()? { vec.push(elem); } - Ok(ParamsValue::Array(vec)) + Ok(Value::Array(vec)) } fn visit_map(self, mut map: A) -> Result @@ -78,11 +78,11 @@ impl<'de> Visitor<'de> for ParamsValueVisitor { while let Some((key, value)) = map.next_entry()? { values.insert(key, value); } - Ok(ParamsValue::Object(values)) + Ok(Value::Object(values)) } } -impl<'de> Deserialize<'de> for ParamsValue { +impl<'de> Deserialize<'de> for Value { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, @@ -92,12 +92,12 @@ impl<'de> Deserialize<'de> for ParamsValue { } struct MapAccessor { - map: std::collections::hash_map::IntoIter, - current_value: Option, + map: std::collections::hash_map::IntoIter, + current_value: Option, } impl MapAccessor { - fn new(map: HashMap) -> Self { + fn new(map: HashMap) -> Self { MapAccessor { map: map.into_iter(), current_value: None, @@ -133,7 +133,7 @@ impl<'de> MapAccess<'de> for MapAccessor { } struct SeqAccessor { - seq: std::vec::IntoIter, + seq: std::vec::IntoIter, } impl<'de> SeqAccess<'de> for SeqAccessor { @@ -150,7 +150,7 @@ impl<'de> SeqAccess<'de> for SeqAccessor { } } -impl<'de> Deserializer<'de> for ParamsValue { +impl<'de> Deserializer<'de> for Value { type Error = serde::de::value::Error; fn deserialize_any(self, visitor: V) -> Result @@ -158,29 +158,29 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Null => visitor.visit_unit(), - ParamsValue::Bool(b) => visitor.visit_bool(b), - ParamsValue::Number(Number(n)) => match n { + Value::Null => visitor.visit_unit(), + Value::Bool(b) => visitor.visit_bool(b), + Value::Number(Number(n)) => match n { N::PosInt(i) => visitor.visit_u64(i), N::NegInt(i) => visitor.visit_i64(i), N::Float(f) => visitor.visit_f64(f), }, - ParamsValue::String(s) => visitor.visit_string(s), - ParamsValue::Object(map) => visitor.visit_map(MapAccessor::new(map)), - ParamsValue::Array(vec) => visitor.visit_seq(SeqAccessor { + Value::String(s) => visitor.visit_string(s), + Value::Object(map) => visitor.visit_map(MapAccessor::new(map)), + Value::Array(vec) => visitor.visit_seq(SeqAccessor { seq: vec.into_iter(), }), - ParamsValue::Convertible(s) => visitor.visit_string(s), - ParamsValue::UploadFile(file) => { + Value::XStr(s) => visitor.visit_string(s), + Value::UploadFile(file) => { let map = HashMap::from([ - ("name".to_string(), ParamsValue::String(file.name.clone())), + ("name".to_string(), Value::String(file.name.clone())), ( "content_type".to_string(), - ParamsValue::String(file.content_type.clone()), + Value::String(file.content_type.clone()), ), ( "temp_file_path".to_string(), - ParamsValue::String(file.temp_file_path.to_string()), + Value::String(file.temp_file_path.to_string()), ), ]); visitor.visit_map(MapAccessor::new(map)) @@ -193,7 +193,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => match s.to_lowercase().as_str() { + Value::XStr(s) => match s.to_lowercase().as_str() { "true" | "1" | "on" | "yes" => visitor.visit_bool(true), "false" | "0" | "off" | "no" => visitor.visit_bool(false), _ => Err(de::Error::custom("invalid boolean value")), @@ -207,7 +207,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_i8(v)), @@ -220,7 +220,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_i16(v)), @@ -233,7 +233,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_i32(v)), @@ -247,7 +247,7 @@ impl<'de> Deserializer<'de> for ParamsValue { { debug!("deserialize_i64 self: {:?}", self); match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_i64(v)), @@ -260,7 +260,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_u8(v)), @@ -273,7 +273,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_u16(v)), @@ -286,7 +286,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_u32(v)), @@ -299,7 +299,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_u64(v)), @@ -313,7 +313,7 @@ impl<'de> Deserializer<'de> for ParamsValue { { debug!("deserialize_f32 self: {:?}", self); match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_f32(v)), @@ -327,7 +327,7 @@ impl<'de> Deserializer<'de> for ParamsValue { { debug!("deserialize_f64 self: {:?}", self); match self { - ParamsValue::Convertible(s) => s + Value::XStr(s) => s .parse() .map_err(de::Error::custom) .and_then(|v| visitor.visit_f64(v)), @@ -340,7 +340,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Convertible(s) => { + Value::XStr(s) => { let mut chars = s.chars(); match (chars.next(), chars.next()) { (Some(c), None) => visitor.visit_char(c), @@ -356,7 +356,7 @@ impl<'de> Deserializer<'de> for ParamsValue { V: Visitor<'de>, { match self { - ParamsValue::Null => visitor.visit_none(), + Value::Null => visitor.visit_none(), _ => visitor.visit_some(self), } } diff --git a/src/tests.rs b/src/tests.rs index 57d892c..d8ba707 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -256,199 +256,6 @@ async fn test_combined_params() { assert_eq!(params.extra, Some("query_param".to_string())); } -#[test] -fn test_process_nested_params() { - let mut input = HashMap::new(); - - // Test simple key-value - input.insert( - "name".to_string(), - vec![ParamsValue::Convertible("john".to_string())], - ); - - // Test nested object - input.insert( - "user[name]".to_string(), - vec![ParamsValue::Convertible("mary".to_string())], - ); - input.insert( - "user[age]".to_string(), - vec![ParamsValue::Convertible("25".to_string())], - ); - input.insert( - "user[address][city]".to_string(), - vec![ParamsValue::Convertible("beijing".to_string())], - ); - input.insert( - "user[address][country]".to_string(), - vec![ParamsValue::Convertible("china".to_string())], - ); - - // Test array - input.insert( - "colors[]".to_string(), - vec![ - ParamsValue::Convertible("red".to_string()), - ParamsValue::Convertible("blue".to_string()), - ], - ); - - // Test indexed array - input.insert( - "numbers[0]".to_string(), - vec![ParamsValue::Convertible("1".to_string())], - ); - input.insert( - "numbers[1]".to_string(), - vec![ParamsValue::Convertible("2".to_string())], - ); - input.insert( - "numbers[2]".to_string(), - vec![ParamsValue::Convertible("3".to_string())], - ); - - // Test array of objects - input.insert( - "users[0][name]".to_string(), - vec![ParamsValue::Convertible("john".to_string())], - ); - input.insert( - "users[0][age]".to_string(), - vec![ParamsValue::Convertible("20".to_string())], - ); - input.insert( - "users[1][name]".to_string(), - vec![ParamsValue::Convertible("mary".to_string())], - ); - input.insert( - "users[1][age]".to_string(), - vec![ParamsValue::Convertible("25".to_string())], - ); - - let result = process_nested_params(input); - debug!("result: {:?}", result); - - // Verify the result - if let ParamsValue::Object(map) = result { - // Test simple key-value - assert_eq!( - map.get("name").unwrap(), - &ParamsValue::Convertible("john".to_string()) - ); - - // Test nested object - if let ParamsValue::Object(user) = map.get("user").unwrap() { - assert_eq!( - user.get("name").unwrap(), - &ParamsValue::Convertible("mary".to_string()) - ); - assert_eq!( - user.get("age").unwrap(), - &ParamsValue::Convertible("25".to_string()) - ); - - if let ParamsValue::Object(address) = user.get("address").unwrap() { - assert_eq!( - address.get("city").unwrap(), - &ParamsValue::Convertible("beijing".to_string()) - ); - assert_eq!( - address.get("country").unwrap(), - &ParamsValue::Convertible("china".to_string()) - ); - } else { - panic!("address should be an object"); - } - } else { - panic!("user should be an object"); - } - - // Test array - if let ParamsValue::Array(colors) = map.get("colors").unwrap() { - assert_eq!(colors.len(), 2); - assert_eq!(colors[0], ParamsValue::Convertible("red".to_string())); - assert_eq!(colors[1], ParamsValue::Convertible("blue".to_string())); - } else { - panic!("colors should be an array"); - } - - // Test indexed array - if let ParamsValue::Array(numbers) = map.get("numbers").unwrap() { - assert_eq!(numbers.len(), 3); - assert_eq!(numbers[0], ParamsValue::Convertible("1".to_string())); - assert_eq!(numbers[1], ParamsValue::Convertible("2".to_string())); - assert_eq!(numbers[2], ParamsValue::Convertible("3".to_string())); - } else { - panic!("numbers should be an array"); - } - - // Test array of objects - if let ParamsValue::Array(users) = map.get("users").unwrap() { - assert_eq!(users.len(), 2); - - if let ParamsValue::Object(user0) = &users[0] { - assert_eq!( - user0.get("name").unwrap(), - &ParamsValue::Convertible("john".to_string()) - ); - assert_eq!( - user0.get("age").unwrap(), - &ParamsValue::Convertible("20".to_string()) - ); - } else { - panic!("users[0] should be an object"); - } - - if let ParamsValue::Object(user1) = &users[1] { - assert_eq!( - user1.get("name").unwrap(), - &ParamsValue::Convertible("mary".to_string()) - ); - assert_eq!( - user1.get("age").unwrap(), - &ParamsValue::Convertible("25".to_string()) - ); - } else { - panic!("users[1] should be an object"); - } - } else { - panic!("users should be an array"); - } - } else { - panic!("result should be an object"); - } -} - -#[test] -fn test_process_nested_with_empty_array() { - let mut input = HashMap::new(); - - // Test array with empty values - input.insert( - "colors[]".to_string(), - vec![ - ParamsValue::Convertible("red".to_string()), - ParamsValue::Convertible("blue".to_string()), - ], - ); - - let result = process_nested_params(input); - - // Verify the result - if let ParamsValue::Object(map) = result { - // Test array - if let ParamsValue::Array(colors) = map.get("colors").unwrap() { - assert_eq!(colors.len(), 2); - assert_eq!(colors[0], ParamsValue::Convertible("red".to_string())); - assert_eq!(colors[1], ParamsValue::Convertible("blue".to_string())); - } else { - panic!("colors should be an array"); - } - } else { - panic!("result should be an object"); - } -} - #[tokio::test] async fn test_nested_params_with_file_upload() { let app = Router::new().route("/api/posts", post(test_nested_params_handler)); @@ -479,16 +286,16 @@ async fn test_nested_params_with_file_upload() { .file_name("cover.jpg") .mime_type("image/jpeg"), ) - .add_text("attachments[0][name]", "First attachment") + .add_text("attachments[][name]", "First attachment") .add_part( - "attachments[0][file]", + "attachments[][file]", Part::bytes(attachment1_content.to_vec()) .file_name("attachment1.txt") .mime_type("text/plain"), ) - .add_text("attachments[1][name]", "Second attachment") + .add_text("attachments[][name]", "Second attachment") .add_part( - "attachments[1][file]", + "attachments[][file]", Part::bytes(attachment2_content.to_vec()) .file_name("attachment2.txt") .mime_type("text/plain"), @@ -641,26 +448,26 @@ async fn test_mixed_create_post() { .add_part("metadata[created_at]", Part::text("2024-12-29")) // Add first attachment with file and metadata .add_part( - "attachments[0][file]", + "attachments[][file]", Part::bytes(vec![1, 2, 3, 4]) .file_name("test1.bin") .mime_type("application/octet-stream"), ) - .add_part("attachments[0][name]", Part::text("Test Attachment 1")) + .add_part("attachments[][name]", Part::text("Test Attachment 1")) .add_part( - "attachments[0][description]", + "attachments[][description]", Part::text("First test attachment"), ) // Add second attachment with file and metadata .add_part( - "attachments[1][file]", + "attachments[][file]", Part::bytes(vec![5, 6, 7, 8, 9]) .file_name("test2.bin") .mime_type("application/octet-stream"), ) - .add_part("attachments[1][name]", Part::text("Test Attachment 2")) + .add_part("attachments[][name]", Part::text("Test Attachment 2")) .add_part( - "attachments[1][description]", + "attachments[][description]", Part::text("Second test attachment"), ); @@ -1032,3 +839,93 @@ async fn test_query_params_numbers() { assert!((params.0.small_float - 0.0000123).abs() < f64::EPSILON); assert!((params.0.exp_num - 123000.0).abs() < f64::EPSILON); } + +#[derive(Debug, Deserialize)] +struct TestEncodedParams { + #[serde(rename = "foo=1")] + foo: Option, + baz: Option, +} + +#[tokio::test] +async fn test_encoded_path_params() { + setup(); + + let req = Request::builder() + .method(http::Method::GET) + .uri("/test?foo%3D1=bar&baz=qux%3D2") + .body(Body::empty()) + .unwrap(); + + let Params(params, _) = Params::::from_request(req, &()) + .await + .unwrap(); + assert_eq!(params.foo, Some("bar".to_string())); + assert_eq!(params.baz, Some("qux=2".to_string())); +} + +#[tokio::test] +async fn test_json_params() { + setup(); + let req = Request::builder() + .method(http::Method::POST) + .header(http::header::CONTENT_TYPE, "application/json") + .uri("/test") + .body(Body::new( + json!({ + "foo=1": "bar", + "baz": "qux=2" + }) + .to_string(), + )) + .unwrap(); + + let Params(params, _) = Params::::from_request(req, &()) + .await + .unwrap(); + assert_eq!(params.foo, Some("bar".to_string())); + assert_eq!(params.baz, Some("qux=2".to_string())); +} + +#[tokio::test] +async fn test_json_params_dont_decode() { + setup(); + let req = Request::builder() + .method(http::Method::POST) + .header(http::header::CONTENT_TYPE, "application/json") + .uri("/test") + .body(Body::new( + json!({ + "foo%3D1": "bar", + "baz": "qux%3D2" + }) + .to_string(), + )) + .unwrap(); + + let Params(params, _) = Params::::from_request(req, &()) + .await + .unwrap(); + assert_eq!(params.foo, None); + assert_eq!(params.baz, Some("qux%3D2".to_string())); +} + +#[tokio::test] +async fn test_encoded_form_params() { + setup(); + let req = Request::builder() + .method(http::Method::POST) + .header( + http::header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .uri("/test") + .body(Body::new("foo%3D1=bar&baz=qux%3D2".to_string())) + .unwrap(); + + let Params(params, _) = Params::::from_request(req, &()) + .await + .unwrap(); + assert_eq!(params.foo, Some("bar".to_string())); + assert_eq!(params.baz, Some("qux=2".to_string())); +} diff --git a/src/value.rs b/src/value.rs index e094acf..70ddd40 100644 --- a/src/value.rs +++ b/src/value.rs @@ -42,53 +42,205 @@ impl From for Number { } } +pub trait IntoNumber { + fn into_number(self) -> Number; +} + +impl IntoNumber for u64 { + fn into_number(self) -> Number { + Number::from(self) + } +} + +impl IntoNumber for i64 { + fn into_number(self) -> Number { + Number::from(self) + } +} + +impl IntoNumber for f64 { + fn into_number(self) -> Number { + Number::from(self) + } +} + #[derive(Debug, Clone)] -pub enum ParamsValue { +pub enum Value { Null, Bool(bool), Number(Number), String(String), - Convertible(String), - Object(HashMap), - Array(Vec), + XStr(String), + Object(HashMap), + Array(Vec), UploadFile(UploadFile), } -impl PartialEq for ParamsValue { +impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::Null, Self::Null) => true, + (Self::XStr(a), b) => match b { + Self::XStr(b) => a == b, + Self::String(b) => a == b, + _ => false, + }, + (a, Self::XStr(b)) => match a { + Self::XStr(a) => a == b, + Self::String(a) => a == b, + _ => false, + }, (Self::Bool(a), Self::Bool(b)) => a == b, (Self::Number(a), Self::Number(b)) => a == b, (Self::String(a), Self::String(b)) => a == b, (Self::Object(a), Self::Object(b)) => a == b, (Self::Array(a), Self::Array(b)) => a == b, - (Self::Convertible(a), Self::Convertible(b)) => a == b, (Self::UploadFile(a), Self::UploadFile(b)) => a == b, _ => false, } } } -fn event_to_params_value(event: &JsonEvent, parser: &JsonParser) -> ParamsValue { - match event { - JsonEvent::ValueString => { - ParamsValue::Convertible(parser.current_str().unwrap().to_string()) +impl From<&serde_json::Value> for Value { + fn from(v: &serde_json::Value) -> Self { + match v { + serde_json::Value::Null => Value::Null, + serde_json::Value::Bool(v) => Value::Bool(*v), + serde_json::Value::Number(n) => { + let n = n.as_f64().unwrap(); + if n.is_nan() { + Value::Null + } else { + Value::Number(n.into()) + } + } + serde_json::Value::String(v) => Value::String(v.clone()), + serde_json::Value::Array(v) => { + Value::Array(v.iter().map(Value::from).collect::>()) + } + serde_json::Value::Object(v) => Value::Object( + v.iter() + .map(|(k, v)| (k.clone(), Value::from(v))) + .collect::>(), + ), + } + } +} + +impl Value { + pub fn merge(self, other: Value) -> Result { + match (self, other) { + // Object + Object = Merged object + (Value::Object(mut a), Value::Object(b)) => { + a.extend(b); + Ok(Value::Object(a)) + } + // Array + Array = Combined array + (Value::Array(mut a), Value::Array(b)) => { + a.extend(b); + Ok(Value::Array(a)) + } + // Array + Any = Array with new element + (Value::Array(mut a), other) => { + a.push(other); + Ok(Value::Array(a)) + } + // Any + Array = Array with new element at start + (value, Value::Array(mut arr)) => { + arr.insert(0, value); + Ok(Value::Array(arr)) + } + // Null + Any = Any + (Value::Null, other) => Ok(other), + // Any + Null = Any + (value, Value::Null) => Ok(value), + // Incompatible types + (a, b) => Err(Error::MergeError(format!( + "Cannot merge {} with {}", + a.type_name(), + b.type_name() + ))), + } + } + + pub fn merge_into( + self, + mut a: HashMap, + ) -> Result, Error> { + match self { + Value::Object(b) => { + a.extend(b); + Ok(a) + } + _ => Err(Error::MergeError(format!( + "Cannot merge {} with object", + self.type_name() + ))), + } + } + + pub fn xstr>(v: T) -> Value { + Value::XStr(v.into()) + } + + pub fn xstr_opt>(v: Option) -> Value { + match v { + Some(v) => Value::XStr(v.into()), + None => Value::Null, } - JsonEvent::ValueInt => ParamsValue::Number(Number::from( + } + + pub fn number(v: T) -> Value { + Value::Number(v.into_number()) + } + + pub fn bool(v: bool) -> Value { + Value::Bool(v) + } + + pub fn null() -> Value { + Value::Null + } + + pub fn array(v: Vec) -> Value { + Value::Array(v) + } + + pub fn object(v: HashMap) -> Value { + Value::Object(v) + } + + pub fn type_name(&self) -> &'static str { + match self { + Value::Null => "null", + Value::Bool(_) => "bool", + Value::Number(_) => "number", + Value::String(_) => "string", + Value::Object(_) => "object", + Value::Array(_) => "array", + Value::XStr(_) => "string", + Value::UploadFile(_) => "file", + } + } +} + +fn event_to_params_value(event: &JsonEvent, parser: &JsonParser) -> Value { + match event { + JsonEvent::ValueString => Value::XStr(parser.current_str().unwrap().to_string()), + JsonEvent::ValueInt => Value::Number(Number::from( parser.current_str().unwrap().parse::().unwrap(), )), - JsonEvent::ValueFloat => ParamsValue::Number(Number::from( + JsonEvent::ValueFloat => Value::Number(Number::from( parser.current_str().unwrap().parse::().unwrap(), )), - JsonEvent::ValueTrue => ParamsValue::Convertible("true".to_string()), - JsonEvent::ValueFalse => ParamsValue::Convertible("false".to_string()), - JsonEvent::ValueNull => ParamsValue::Null, + JsonEvent::ValueTrue => Value::XStr("true".to_string()), + JsonEvent::ValueFalse => Value::XStr("false".to_string()), + JsonEvent::ValueNull => Value::Null, _ => unreachable!(), } } -pub fn parse_json(feeder: SliceJsonFeeder) -> Result { +pub fn parse_json(feeder: SliceJsonFeeder) -> Result { let mut parser = JsonParser::new(feeder); let mut stack = vec![]; @@ -102,9 +254,9 @@ pub fn parse_json(feeder: SliceJsonFeeder) -> Result { JsonEvent::StartObject | JsonEvent::StartArray => { let v = if event == JsonEvent::StartObject { - ParamsValue::Object(HashMap::new()) + Value::Object(HashMap::new()) } else { - ParamsValue::Array(vec![]) + Value::Array(vec![]) }; stack.push((current_key.take(), v)); } @@ -113,12 +265,12 @@ pub fn parse_json(feeder: SliceJsonFeeder) -> Result { let v = stack.pop().unwrap(); if let Some((_, top)) = stack.last_mut() { match top { - ParamsValue::Object(o) => { + Value::Object(o) => { if let Some(key) = v.0 { o.insert(key, v.1); } } - ParamsValue::Array(a) => { + Value::Array(a) => { a.push(v.1); } _ => return Err(JsonError::SyntaxError), @@ -142,10 +294,10 @@ pub fn parse_json(feeder: SliceJsonFeeder) -> Result { let v = event_to_params_value(&event, &parser); if let Some((_, top)) = stack.last_mut() { match top { - ParamsValue::Array(a) => { + Value::Array(a) => { a.push(v); } - ParamsValue::Object(o) => { + Value::Object(o) => { if let Some(key) = current_key.take() { o.insert(key, v); } else { @@ -185,26 +337,6 @@ impl From for Error { } } -pub fn merge_json( - feeder: SliceJsonFeeder, - merged: &mut HashMap>, -) -> Result<(), JsonError> { - let value = parse_json(feeder)?; - debug!("Parsed JSON value: {:#?}", value); - match value { - ParamsValue::Object(obj) => { - for (key, value) in obj { - merged.insert(key, vec![value]); - } - } - _ => { - merged.insert("".to_string(), vec![value]); - } - } - debug!("Final merged map: {:#?}", merged); - Ok(()) -} - #[cfg(test)] mod tests { use super::*; @@ -272,18 +404,12 @@ mod tests { // Test positive integers let json = r#"{"pos": 42, "zero": 0, "big": 9007199254740991}"#; let result = parse_json(SliceJsonFeeder::new(json.as_bytes())).unwrap(); - if let ParamsValue::Object(map) = result { - assert!(matches!( - map["pos"], - ParamsValue::Number(Number(N::PosInt(42))) - )); - assert!(matches!( - map["zero"], - ParamsValue::Number(Number(N::PosInt(0))) - )); + if let Value::Object(map) = result { + assert!(matches!(map["pos"], Value::Number(Number(N::PosInt(42))))); + assert!(matches!(map["zero"], Value::Number(Number(N::PosInt(0))))); assert!(matches!( map["big"], - ParamsValue::Number(Number(N::PosInt(9007199254740991))) + Value::Number(Number(N::PosInt(9007199254740991))) )); } else { panic!("Expected object"); @@ -292,14 +418,11 @@ mod tests { // Test negative integers let json = r#"{"neg": -42, "min": -9007199254740991}"#; let result = parse_json(SliceJsonFeeder::new(json.as_bytes())).unwrap(); - if let ParamsValue::Object(map) = result { - assert!(matches!( - map["neg"], - ParamsValue::Number(Number(N::NegInt(-42))) - )); + if let Value::Object(map) = result { + assert!(matches!(map["neg"], Value::Number(Number(N::NegInt(-42))))); assert!(matches!( map["min"], - ParamsValue::Number(Number(N::NegInt(-9007199254740991))) + Value::Number(Number(N::NegInt(-9007199254740991))) )); } else { panic!("Expected object"); @@ -314,21 +437,21 @@ mod tests { "neg_exp": -1.23e-5 }"#; let result = parse_json(SliceJsonFeeder::new(json.as_bytes())).unwrap(); - if let ParamsValue::Object(map) = result { + if let Value::Object(map) = result { assert!( - matches!(map["float"], ParamsValue::Number(Number(N::Float(v))) if (v - 42.5).abs() < f64::EPSILON) + matches!(map["float"], Value::Number(Number(N::Float(v))) if (v - 42.5).abs() < f64::EPSILON) ); assert!( - matches!(map["neg_float"], ParamsValue::Number(Number(N::Float(v))) if (v - (-42.5)).abs() < f64::EPSILON) + matches!(map["neg_float"], Value::Number(Number(N::Float(v))) if (v - (-42.5)).abs() < f64::EPSILON) ); assert!( - matches!(map["zero_float"], ParamsValue::Number(Number(N::Float(v))) if v.abs() < f64::EPSILON) + matches!(map["zero_float"], Value::Number(Number(N::Float(v))) if v.abs() < f64::EPSILON) ); assert!( - matches!(map["exp"], ParamsValue::Number(Number(N::Float(v))) if (v - 123000.0).abs() < f64::EPSILON) + matches!(map["exp"], Value::Number(Number(N::Float(v))) if (v - 123000.0).abs() < f64::EPSILON) ); assert!( - matches!(map["neg_exp"], ParamsValue::Number(Number(N::Float(v))) if (v - (-0.0000123)).abs() < f64::EPSILON) + matches!(map["neg_exp"], Value::Number(Number(N::Float(v))) if (v - (-0.0000123)).abs() < f64::EPSILON) ); } else { panic!("Expected object"); @@ -337,19 +460,14 @@ mod tests { // Test array of numbers let json = r#"[42, -42, 42.5, 0, -0.0]"#; let result = parse_json(SliceJsonFeeder::new(json.as_bytes())).unwrap(); - if let ParamsValue::Array(arr) = result { - assert!(matches!(arr[0], ParamsValue::Number(Number(N::PosInt(42))))); - assert!(matches!( - arr[1], - ParamsValue::Number(Number(N::NegInt(-42))) - )); - assert!( - matches!(arr[2], ParamsValue::Number(Number(N::Float(v))) if (v - 42.5).abs() < f64::EPSILON) - ); - assert!(matches!(arr[3], ParamsValue::Number(Number(N::PosInt(0))))); + if let Value::Array(arr) = result { + assert!(matches!(arr[0], Value::Number(Number(N::PosInt(42))))); + assert!(matches!(arr[1], Value::Number(Number(N::NegInt(-42))))); assert!( - matches!(arr[4], ParamsValue::Number(Number(N::Float(v))) if v.abs() < f64::EPSILON) + matches!(arr[2], Value::Number(Number(N::Float(v))) if (v - 42.5).abs() < f64::EPSILON) ); + assert!(matches!(arr[3], Value::Number(Number(N::PosInt(0))))); + assert!(matches!(arr[4], Value::Number(Number(N::Float(v))) if v.abs() < f64::EPSILON)); } else { panic!("Expected array"); } @@ -366,32 +484,26 @@ mod tests { "nested": {"a": 1, "b": 2} }"#; let result = parse_json(SliceJsonFeeder::new(json.as_bytes())).unwrap(); - if let ParamsValue::Object(map) = result { + if let Value::Object(map) = result { assert!(matches!( map["number"], - ParamsValue::Number(Number(N::PosInt(42))) + Value::Number(Number(N::PosInt(42))) )); - assert!(matches!(map["string"], ParamsValue::Convertible(ref s) if s == "hello")); - assert!(matches!(map["bool"], ParamsValue::Convertible(ref s) if s == "true")); - assert!(matches!(map["null"], ParamsValue::Null)); - - if let ParamsValue::Array(arr) = &map["array"] { - assert!(matches!(arr[0], ParamsValue::Number(Number(N::PosInt(1))))); - assert!(matches!(arr[1], ParamsValue::Convertible(ref s) if s == "two")); - assert!(matches!(arr[2], ParamsValue::Convertible(ref s) if s == "false")); + assert!(matches!(map["string"], Value::XStr(ref s) if s == "hello")); + assert!(matches!(map["bool"], Value::XStr(ref s) if s == "true")); + assert!(matches!(map["null"], Value::Null)); + + if let Value::Array(arr) = &map["array"] { + assert!(matches!(arr[0], Value::Number(Number(N::PosInt(1))))); + assert!(matches!(arr[1], Value::XStr(ref s) if s == "two")); + assert!(matches!(arr[2], Value::XStr(ref s) if s == "false")); } else { panic!("Expected array"); } - if let ParamsValue::Object(nested) = &map["nested"] { - assert!(matches!( - nested["a"], - ParamsValue::Number(Number(N::PosInt(1))) - )); - assert!(matches!( - nested["b"], - ParamsValue::Number(Number(N::PosInt(2))) - )); + if let Value::Object(nested) = &map["nested"] { + assert!(matches!(nested["a"], Value::Number(Number(N::PosInt(1))))); + assert!(matches!(nested["b"], Value::Number(Number(N::PosInt(2))))); } else { panic!("Expected nested object"); }