diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index dc8f55b7c..91a0af1a5 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3904,6 +3904,7 @@ def url_schema( default_host: str | None = None, default_port: int | None = None, default_path: str | None = None, + preserve_empty_path: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, @@ -3928,6 +3929,7 @@ def url_schema( default_host: The default host to use if the URL does not have a host default_port: The default port to use if the URL does not have a port default_path: The default path to use if the URL does not have a path + preserve_empty_path: Whether to preserve an empty path or convert it to '/', default False strict: Whether to use strict URL parsing ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3941,6 +3943,7 @@ def url_schema( default_host=default_host, default_port=default_port, default_path=default_path, + preserve_empty_path=preserve_empty_path, strict=strict, ref=ref, metadata=metadata, @@ -3970,6 +3973,7 @@ def multi_host_url_schema( default_host: str | None = None, default_port: int | None = None, default_path: str | None = None, + preserve_empty_path: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, @@ -3994,6 +3998,7 @@ def multi_host_url_schema( default_host: The default host to use if the URL does not have a host default_port: The default port to use if the URL does not have a port default_path: The default path to use if the URL does not have a path + preserve_empty_path: Whether to preserve an empty path or convert it to '/', default False strict: Whether to use strict URL parsing ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -4007,6 +4012,7 @@ def multi_host_url_schema( default_host=default_host, default_port=default_port, default_path=default_path, + preserve_empty_path=preserve_empty_path, strict=strict, ref=ref, metadata=metadata, diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index acd94c3d1..f40b827a1 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -197,11 +197,11 @@ pub(crate) fn infer_to_python_known( } ObType::Url => { let py_url: PyUrl = value.extract()?; - py_url.__str__().into_py_any(py)? + py_url.__str__(py).into_py_any(py)? } ObType::MultiHostUrl => { let py_url: PyMultiHostUrl = value.extract()?; - py_url.__str__().into_py_any(py)? + py_url.__str__(py).into_py_any(py)? } ObType::Uuid => { let uuid = super::type_serializers::uuid::uuid_to_string(value)?; @@ -476,11 +476,11 @@ pub(crate) fn infer_serialize_known( } ObType::Url => { let py_url: PyUrl = value.extract().map_err(py_err_se_err)?; - serializer.serialize_str(py_url.__str__()) + serializer.serialize_str(py_url.__str__(value.py())) } ObType::MultiHostUrl => { let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?; - serializer.serialize_str(&py_url.__str__()) + serializer.serialize_str(&py_url.__str__(value.py())) } ObType::PydanticSerializable => { let py = value.py(); @@ -644,11 +644,11 @@ pub(crate) fn infer_json_key_known<'a>( } ObType::Url => { let py_url: PyUrl = key.extract()?; - Ok(Cow::Owned(py_url.__str__().to_string())) + Ok(Cow::Owned(py_url.__str__(key.py()).to_string())) } ObType::MultiHostUrl => { let py_url: PyMultiHostUrl = key.extract()?; - Ok(Cow::Owned(py_url.__str__())) + Ok(Cow::Owned(py_url.__str__(key.py()).to_string())) } ObType::Tuple => { let mut key_build = super::type_serializers::tuple::KeyBuilder::new(); diff --git a/src/serializers/type_serializers/url.rs b/src/serializers/type_serializers/url.rs index 1e697d78b..76f13153d 100644 --- a/src/serializers/type_serializers/url.rs +++ b/src/serializers/type_serializers/url.rs @@ -43,7 +43,7 @@ macro_rules! build_serializer { let py = value.py(); match value.extract::<$extract>() { Ok(py_url) => match extra.mode { - SerMode::Json => py_url.__str__().into_py_any(py), + SerMode::Json => py_url.__str__(value.py()).into_py_any(py), _ => Ok(value.clone().unbind()), }, Err(_) => { @@ -55,7 +55,7 @@ macro_rules! build_serializer { fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { match key.extract::<$extract>() { - Ok(py_url) => Ok(Cow::Owned(py_url.__str__().to_string())), + Ok(py_url) => Ok(Cow::Owned(py_url.__str__(key.py()).to_string())), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), key, extra)?; infer_json_key(key, extra) @@ -72,7 +72,7 @@ macro_rules! build_serializer { extra: &Extra, ) -> Result { match value.extract::<$extract>() { - Ok(py_url) => serializer.serialize_str(&py_url.__str__()), + Ok(py_url) => serializer.serialize_str(&py_url.__str__(value.py())), Err(_) => { extra .warnings diff --git a/src/url.rs b/src/url.rs index f824e9296..9429b0ebb 100644 --- a/src/url.rs +++ b/src/url.rs @@ -1,12 +1,14 @@ +use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; use std::fmt; use std::fmt::Formatter; use std::hash::{Hash, Hasher}; +use std::sync::OnceLock; use idna::punycode::decode_to_string; use pyo3::exceptions::PyValueError; use pyo3::pyclass::CompareOp; -use pyo3::sync::GILOnceCell; +use pyo3::sync::{GILOnceCell, OnceLockExt}; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*, IntoPyObjectExt}; use url::Url; @@ -14,38 +16,71 @@ use url::Url; use crate::tools::SchemaDict; use crate::SchemaValidator; -static SCHEMA_DEFINITION_URL: GILOnceCell = GILOnceCell::new(); - #[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass, frozen)] -#[derive(Clone, Hash)] +#[derive(Clone)] #[cfg_attr(debug_assertions, derive(Debug))] pub struct PyUrl { lib_url: Url, + /// Override to treat the path as empty when it is `/`. The `url` crate always normalizes an empty path to `/`, + /// but users may want to preserve the empty path when round-tripping. + path_is_empty: bool, + /// Cache for the serialized representation where this diverges from `lib_url.as_str()` + /// (i.e. when trailing slash was added to the empty path, but user didn't want that) + serialized: OnceLock, +} + +impl Hash for PyUrl { + fn hash(&self, state: &mut H) { + self.lib_url.hash(state); + self.path_is_empty.hash(state); + // no need to hash `serialized` as it's derived from the other two fields + } } impl PyUrl { - pub fn new(lib_url: Url) -> Self { - Self { lib_url } + pub fn new(lib_url: Url, path_is_empty: bool) -> Self { + Self { + lib_url, + path_is_empty, + serialized: OnceLock::new(), + } } pub fn url(&self) -> &Url { &self.lib_url } -} -fn build_schema_validator(py: Python, schema_type: &str) -> SchemaValidator { - let schema = PyDict::new(py); - schema.set_item("type", schema_type).unwrap(); - SchemaValidator::py_new(py, &schema, None).unwrap() + pub fn url_mut(&mut self) -> &mut Url { + &mut self.lib_url + } + + fn serialized(&self, py: Python<'_>) -> &str { + if self.path_is_empty { + self.serialized + .get_or_init_py_attached(py, || serialize_url_without_path_slash(&self.lib_url)) + } else { + self.lib_url.as_str() + } + } } #[pymethods] impl PyUrl { #[new] - pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult { - let schema_obj = SCHEMA_DEFINITION_URL - .get_or_init(py, || build_schema_validator(py, "url")) - .validate_python(py, url, None, None, None, None, None, false.into(), None, None)?; + #[pyo3(signature = (url, *, preserve_empty_path=false))] + pub fn py_new(py: Python, url: &Bound<'_, PyAny>, preserve_empty_path: bool) -> PyResult { + let schema_obj = get_schema_validator(py, false, preserve_empty_path)?.validate_python( + py, + url, + None, + None, + None, + None, + None, + false.into(), + None, + None, + )?; schema_obj.extract(py) } @@ -89,6 +124,7 @@ impl PyUrl { pub fn path(&self) -> Option<&str> { match self.lib_url.path() { "" => None, + "/" if self.path_is_empty => None, path => Some(path), } } @@ -113,16 +149,16 @@ impl PyUrl { } // string representation of the URL, with punycode decoded when appropriate - pub fn unicode_string(&self) -> String { - unicode_url(&self.lib_url) + pub fn unicode_string(&self, py: Python<'_>) -> Cow<'_, str> { + unicode_url(self.serialized(py), &self.lib_url) } - pub fn __str__(&self) -> &str { - self.lib_url.as_str() + pub fn __str__(&self, py: Python<'_>) -> &str { + self.serialized(py) } - pub fn __repr__(&self) -> String { - format!("Url('{}')", self.lib_url) + pub fn __repr__(&self, py: Python<'_>) -> String { + format!("Url('{}')", self.serialized(py)) } fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { @@ -151,8 +187,8 @@ impl PyUrl { self.clone().into_py_any(py) } - fn __getnewargs__(&self) -> (&str,) { - (self.__str__(),) + fn __getnewargs__(&self, py: Python<'_>) -> (&str,) { + (self.__str__(py),) } #[classmethod] @@ -201,11 +237,8 @@ pub struct PyMultiHostUrl { } impl PyMultiHostUrl { - pub fn new(ref_url: Url, extra_urls: Option>) -> Self { - Self { - ref_url: PyUrl::new(ref_url), - extra_urls, - } + pub fn new(ref_url: PyUrl, extra_urls: Option>) -> Self { + Self { ref_url, extra_urls } } pub fn lib_url(&self) -> &Url { @@ -217,15 +250,23 @@ impl PyMultiHostUrl { } } -static SCHEMA_DEFINITION_MULTI_HOST_URL: GILOnceCell = GILOnceCell::new(); - #[pymethods] impl PyMultiHostUrl { #[new] - pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult { - let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL - .get_or_init(py, || build_schema_validator(py, "multi-host-url")) - .validate_python(py, url, None, None, None, None, None, false.into(), None, None)?; + #[pyo3(signature = (url, *, preserve_empty_path=false))] + pub fn py_new(py: Python, url: &Bound<'_, PyAny>, preserve_empty_path: bool) -> PyResult { + let schema_obj = get_schema_validator(py, true, preserve_empty_path)?.validate_python( + py, + url, + None, + None, + None, + None, + None, + false.into(), + None, + None, + )?; schema_obj.extract(py) } @@ -269,12 +310,12 @@ impl PyMultiHostUrl { } // string representation of the URL, with punycode decoded when appropriate - pub fn unicode_string(&self) -> String { + pub fn unicode_string(&self, py: Python<'_>) -> Cow<'_, str> { if let Some(extra_urls) = &self.extra_urls { let scheme = self.ref_url.lib_url.scheme(); let host_offset = scheme.len() + 3; - let mut full_url = self.ref_url.unicode_string(); + let mut full_url = self.ref_url.unicode_string(py).into_owned(); full_url.insert(host_offset, ','); // special urls will have had a trailing slash added, non-special urls will not @@ -285,24 +326,24 @@ impl PyMultiHostUrl { let hosts = extra_urls .iter() .map(|url| { - let str = unicode_url(url); + let str = unicode_url(url.as_str(), url); str[host_offset..str.len() - sub].to_string() }) .collect::>() .join(","); full_url.insert_str(host_offset, &hosts); - full_url + Cow::Owned(full_url) } else { - self.ref_url.unicode_string() + self.ref_url.unicode_string(py) } } - pub fn __str__(&self) -> String { + pub fn __str__(&self, py: Python<'_>) -> String { if let Some(extra_urls) = &self.extra_urls { let scheme = self.ref_url.lib_url.scheme(); let host_offset = scheme.len() + 3; - let mut full_url = self.ref_url.lib_url.to_string(); + let mut full_url = self.ref_url.serialized(py).to_string(); full_url.insert(host_offset, ','); // special urls will have had a trailing slash added, non-special urls will not @@ -321,22 +362,22 @@ impl PyMultiHostUrl { full_url.insert_str(host_offset, &hosts); full_url } else { - self.ref_url.__str__().to_string() + self.ref_url.__str__(py).to_string() } } - pub fn __repr__(&self) -> String { - format!("MultiHostUrl('{}')", self.__str__()) + pub fn __repr__(&self, py: Python<'_>) -> String { + format!("MultiHostUrl('{}')", self.__str__(py)) } - fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult { match op { - CompareOp::Lt => Ok(self.unicode_string() < other.unicode_string()), - CompareOp::Le => Ok(self.unicode_string() <= other.unicode_string()), - CompareOp::Eq => Ok(self.unicode_string() == other.unicode_string()), - CompareOp::Ne => Ok(self.unicode_string() != other.unicode_string()), - CompareOp::Gt => Ok(self.unicode_string() > other.unicode_string()), - CompareOp::Ge => Ok(self.unicode_string() >= other.unicode_string()), + CompareOp::Lt => Ok(self.unicode_string(py) < other.unicode_string(py)), + CompareOp::Le => Ok(self.unicode_string(py) <= other.unicode_string(py)), + CompareOp::Eq => Ok(self.unicode_string(py) == other.unicode_string(py)), + CompareOp::Ne => Ok(self.unicode_string(py) != other.unicode_string(py)), + CompareOp::Gt => Ok(self.unicode_string(py) > other.unicode_string(py)), + CompareOp::Ge => Ok(self.unicode_string(py) >= other.unicode_string(py)), } } @@ -354,8 +395,8 @@ impl PyMultiHostUrl { self.clone().into_py_any(py) } - fn __getnewargs__(&self) -> (String,) { - (self.__str__(),) + fn __getnewargs__(&self, py: Python<'_>) -> (String,) { + (self.__str__(py),) } #[classmethod] @@ -477,19 +518,18 @@ fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult Ok(dict) } -fn unicode_url(lib_url: &Url) -> String { - let mut s = lib_url.to_string(); - +fn unicode_url<'s>(serialized: &'s str, lib_url: &Url) -> Cow<'s, str> { match lib_url.host() { Some(url::Host::Domain(domain)) if is_punnycode_domain(lib_url, domain) => { + let mut s = serialized.to_string(); if let Some(decoded) = decode_punycode(domain) { // replace the range containing the punycode domain with the decoded domain let start = lib_url.scheme().len() + 3; s.replace_range(start..start + domain.len(), &decoded); } - s + Cow::Owned(s) } - _ => s, + _ => Cow::Borrowed(serialized), } } @@ -517,3 +557,53 @@ fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool { pub fn scheme_is_special(scheme: &str) -> bool { matches!(scheme, "http" | "https" | "ws" | "wss" | "ftp" | "file") } + +fn serialize_url_without_path_slash(url: &Url) -> String { + // use pointer arithmetic to find the pieces we need to build the string + let s = url.as_str(); + let path = url.path(); + assert_eq!(path, "/", "`path_is_empty` expected to be set only when path is '/'"); + + assert!( + // Safety for the below: `s` and `path` should be from the same text slice, so + // we can pull out the slices of `s` that don't include `path`. + s.as_ptr() <= path.as_ptr() && unsafe { s.as_ptr().add(s.len()) } >= unsafe { path.as_ptr().add(path.len()) } + ); + + let prefix_len = path.as_ptr() as usize - s.as_ptr() as usize; + let suffix_len = s.len() - (prefix_len + path.len()); + + // Safety: prefix is the slice of `s` leading to `path`, protected by the assert above. + let prefix = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(s.as_ptr(), prefix_len)) }; + // Safety: suffix is the slice of `s` after `path`, protected by the assert above. + let suffix = + unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(path.as_ptr().add(path.len()), suffix_len)) }; + + format!("{prefix}{suffix}") +} + +static SCHEMA_URL_SINGLE_TRUE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_SINGLE_FALSE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_MULTI_TRUE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_MULTI_FALSE: GILOnceCell = GILOnceCell::new(); + +macro_rules! make_schema_val { + ($py:ident, $schema_type:literal, $preserve_empty_path:literal) => {{ + let schema = PyDict::new($py); + schema.set_item(intern!($py, "type"), intern!($py, $schema_type))?; + // preserve_empty_path defaults to false, so only set it if true + if $preserve_empty_path { + schema.set_item(intern!($py, "preserve_empty_path"), true)?; + } + SchemaValidator::py_new($py, &schema, None) + }}; +} + +fn get_schema_validator(py: Python<'_>, multi_host: bool, preserve_empty_path: bool) -> PyResult<&SchemaValidator> { + match (multi_host, preserve_empty_path) { + (false, true) => SCHEMA_URL_SINGLE_TRUE.get_or_try_init(py, || make_schema_val!(py, "url", true)), + (false, false) => SCHEMA_URL_SINGLE_FALSE.get_or_try_init(py, || make_schema_val!(py, "url", false)), + (true, true) => SCHEMA_URL_MULTI_TRUE.get_or_try_init(py, || make_schema_val!(py, "multi-host-url", true)), + (true, false) => SCHEMA_URL_MULTI_FALSE.get_or_try_init(py, || make_schema_val!(py, "multi-host-url", false)), + } +} diff --git a/src/validators/url.rs b/src/validators/url.rs index fd11137b4..11bcac1a1 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cell::RefCell; use std::iter::Peekable; use std::str::Chars; @@ -10,8 +11,8 @@ use ahash::AHashSet; use pyo3::IntoPyObjectExt; use url::{ParseError, SyntaxViolation, Url}; +use crate::build_tools::schema_or_config; use crate::build_tools::{is_strict, py_schema_err}; -use crate::errors::ToErrorValue; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::downcast_python_input; use crate::input::Input; @@ -35,6 +36,17 @@ pub struct UrlValidator { default_port: Option, default_path: Option, name: String, + preserve_empty_path: bool, +} + +fn get_preserve_empty_path(schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult { + schema_or_config( + schema, + config, + intern!(schema.py(), "preserve_empty_path"), + intern!(schema.py(), "url_preserve_empty_path"), + ) + .map(|v| v.unwrap_or(false)) } impl BuildValidator for UrlValidator { @@ -56,6 +68,7 @@ impl BuildValidator for UrlValidator { default_path: schema.get_as(intern!(schema.py(), "default_path"))?, allowed_schemes, name, + preserve_empty_path: get_preserve_empty_path(schema, config)?, } .into()) } @@ -70,7 +83,7 @@ impl Validator for UrlValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let mut either_url = self.get_url(input, state.strict_or(self.strict))?; + let mut either_url = self.get_url(py, input, state.strict_or(self.strict))?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(either_url.url().scheme()) { @@ -107,25 +120,34 @@ impl Validator for UrlValidator { } impl UrlValidator { - fn get_url<'py>(&self, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult> { + fn get_url<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + strict: bool, + ) -> ValResult> { if let Some(py_url) = downcast_python_input::(input) { // we don't need to worry about whether the url was parsed in strict mode before, // even if it was, any syntax errors would have been fixed by the first validation - self.check_length(input, py_url.get().url().as_str())?; - Ok(EitherUrl::Py(py_url.clone())) - } else if let Some(multi_host_url) = downcast_python_input::(input) { - let url_str = multi_host_url.get().__str__(); - self.check_length(input, &url_str)?; - parse_url(&url_str, input, strict).map(EitherUrl::Rust) - } else if let Ok(either_str) = input.validate_str(strict, false).map(ValidationMatch::into_inner) { - let cow = either_str.as_cow()?; - let url_str = cow.as_ref(); + self.check_length(input, py_url.get().__str__(py))?; + return Ok(EitherUrl::Py(py_url.clone())); + } - self.check_length(input, url_str)?; - parse_url(url_str, input, strict).map(EitherUrl::Rust) + let either_str_owned; + let url_str = if let Some(multi_host_url) = downcast_python_input::(input) { + Cow::Owned(multi_host_url.get().__str__(py)) + } else if let Ok(either_str) = input.validate_str(strict, false).map(ValidationMatch::into_inner) { + either_str_owned = either_str; // to extend the lifetime outside the if let + either_str_owned.as_cow()? } else { - Err(ValError::new(ErrorTypeDefaults::UrlType, input)) - } + return Err(ValError::new(ErrorTypeDefaults::UrlType, input)); + }; + + let url_str = url_str.as_ref(); + self.check_length(input, url_str)?; + let url = parse_url(url_str, input, strict)?; + let path_is_empty = need_to_preserve_empty_path(&url, url_str, self.preserve_empty_path); + Ok(EitherUrl::Rust(PyUrl::new(url, path_is_empty))) } fn check_length<'py>(&self, input: &(impl Input<'py> + ?Sized), url_str: &str) -> ValResult<()> { @@ -146,7 +168,7 @@ impl UrlValidator { enum EitherUrl<'py> { Py(Bound<'py, PyUrl>), - Rust(Url), + Rust(PyUrl), } impl<'py> IntoPyObject<'py> for EitherUrl<'py> { @@ -157,7 +179,7 @@ impl<'py> IntoPyObject<'py> for EitherUrl<'py> { fn into_pyobject(self, py: Python<'py>) -> PyResult { match self { EitherUrl::Py(py_url) => Ok(py_url), - EitherUrl::Rust(rust_url) => Bound::new(py, PyUrl::new(rust_url)), + EitherUrl::Rust(rust_url) => Bound::new(py, rust_url), } } } @@ -166,17 +188,17 @@ impl CopyFromPyUrl for EitherUrl<'_> { fn url(&self) -> &Url { match self { EitherUrl::Py(py_url) => py_url.get().url(), - EitherUrl::Rust(rust_url) => rust_url, + EitherUrl::Rust(rust_url) => rust_url.url(), } } fn url_mut(&mut self) -> &mut Url { if let EitherUrl::Py(py_url) = self { - *self = EitherUrl::Rust(py_url.get().url().clone()); + *self = EitherUrl::Rust(py_url.get().clone()); } match self { EitherUrl::Py(_) => unreachable!(), - EitherUrl::Rust(rust_url) => rust_url, + EitherUrl::Rust(rust_url) => rust_url.url_mut(), } } } @@ -191,6 +213,7 @@ pub struct MultiHostUrlValidator { default_port: Option, default_path: Option, name: String, + preserve_empty_path: bool, } impl BuildValidator for MultiHostUrlValidator { @@ -218,6 +241,7 @@ impl BuildValidator for MultiHostUrlValidator { default_port: schema.get_as(intern!(schema.py(), "default_port"))?, default_path: schema.get_as(intern!(schema.py(), "default_path"))?, name, + preserve_empty_path: get_preserve_empty_path(schema, config)?, } .into()) } @@ -232,7 +256,7 @@ impl Validator for MultiHostUrlValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let mut multi_url = self.get_url(input, state.strict_or(self.strict))?; + let mut multi_url = self.get_url(py, input, state.strict_or(self.strict))?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(multi_url.url().scheme()) { @@ -268,16 +292,21 @@ impl Validator for MultiHostUrlValidator { } impl MultiHostUrlValidator { - fn get_url<'py>(&self, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult> { + fn get_url<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + strict: bool, + ) -> ValResult> { // we don't need to worry about whether the url was parsed in strict mode before, // even if it was, any syntax errors would have been fixed by the first validation if let Some(multi_url) = downcast_python_input::(input) { - self.check_length(input, || multi_url.get().__str__().len())?; + self.check_length(input, || multi_url.get().__str__(py).len())?; Ok(EitherMultiHostUrl::Py(multi_url.clone())) } else if let Some(py_url) = downcast_python_input::(input) { - self.check_length(input, || py_url.get().url().as_str().len())?; + self.check_length(input, || py_url.get().__str__(py).len())?; Ok(EitherMultiHostUrl::Rust(PyMultiHostUrl::new( - py_url.get().url().clone(), + py_url.get().clone(), None, ))) } else if let Ok(either_str) = input.validate_str(strict, false).map(ValidationMatch::into_inner) { @@ -286,7 +315,7 @@ impl MultiHostUrlValidator { self.check_length(input, || url_str.len())?; - parse_multihost_url(url_str, input, strict).map(EitherMultiHostUrl::Rust) + parse_multihost_url(url_str, input, strict, self.preserve_empty_path).map(EitherMultiHostUrl::Rust) } else { Err(ValError::new(ErrorTypeDefaults::UrlType, input)) } @@ -352,6 +381,7 @@ fn parse_multihost_url<'py>( url_str: &str, input: &(impl Input<'py> + ?Sized), strict: bool, + preserve_empty_path: bool, ) -> ValResult { macro_rules! parsing_err { ($parse_error:expr) => { @@ -442,13 +472,16 @@ fn parse_multihost_url<'py>( let reconstructed_url = format!("{prefix}{}", &url_str[start..]); let ref_url = parse_url(&reconstructed_url, input, strict)?; + let path_is_empty = need_to_preserve_empty_path(&ref_url, &reconstructed_url, preserve_empty_path); + + let ref_url = PyUrl::new(ref_url, path_is_empty); if hosts.is_empty() { // if there's no one host (e.g. no `,`), we allow it to be empty to allow for default hosts Ok(PyMultiHostUrl::new(ref_url, None)) } else { // with more than one host, none of them can be empty - if !ref_url.has_host() { + if !ref_url.url().has_host() { return parsing_err!(ParseError::EmptyHost); } let extra_urls: Vec = hosts @@ -467,7 +500,7 @@ fn parse_multihost_url<'py>( } } -fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult { +fn parse_url<'py>(url_str: &str, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult { if url_str.is_empty() { return Err(ValError::new( ErrorType::UrlParsing { @@ -478,45 +511,21 @@ fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult )); } - // if we're in strict mode, we collect consider a syntax violation as an error - if strict { - // we could build a vec of syntax violations and return them all, but that seems like overkill - // and unlike other parser style validators - let vios: RefCell> = RefCell::new(None); - let r = Url::options() - .syntax_violation_callback(Some(&|v| { - match v { - // telling users offer about credentials in URLs doesn't really make sense in this context - SyntaxViolation::EmbeddedCredentials => (), - _ => *vios.borrow_mut() = Some(v), - } - })) - .parse(url_str); - - match r { - Ok(url) => { - if let Some(vio) = vios.into_inner() { - Err(ValError::new( - ErrorType::UrlSyntaxViolation { - error: vio.description().into(), - context: None, - }, - input, - )) - } else { - Ok(url) - } + // we could build a vec of syntax violations and return them all, but that seems like overkill + // and unlike other parser style validators + let vios = RefCell::new(None); + + let url = Url::options() + // if we're in strict mode, we collect consider a syntax violation as an error + .syntax_violation_callback(strict.then_some(&|v| { + match v { + // telling users offer about credentials in URLs doesn't really make sense in this context + SyntaxViolation::EmbeddedCredentials => (), + _ => *vios.borrow_mut() = Some(v), } - Err(e) => Err(ValError::new( - ErrorType::UrlParsing { - error: e.to_string(), - context: None, - }, - input, - )), - } - } else { - Url::parse(url_str).map_err(move |e| { + })) + .parse(url_str) + .map_err(|e| { ValError::new( ErrorType::UrlParsing { error: e.to_string(), @@ -524,8 +533,56 @@ fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult }, input, ) - }) + })?; + + if let Some(vio) = vios.into_inner() { + return Err(ValError::new( + ErrorType::UrlSyntaxViolation { + error: vio.description().into(), + context: None, + }, + input, + )); + } + + Ok(url) +} + +/// Check if the path got normalized to `/` and the original string had an empty path +fn need_to_preserve_empty_path(url: &Url, url_str: &str, preserve_empty_path: bool) -> bool { + if !preserve_empty_path { + return false; } + + if url.path() != "/" { + // was definitely not the case + return false; + } + + if !scheme_is_special(url.scheme()) { + // non-special schemes don't normalize the path + return false; + } + + // find the scheme marker in the original input + let (_, input_without_scheme) = url_str.split_once(':').expect("url has a scheme"); + + // strip any leading / (which would be part of the authority marker), URL will normalize any + // number of them even if there should only be two + let input_without_scheme = input_without_scheme.trim_start_matches('/'); + + // Now find the start of the path, which is either the first /, ?, or #, or the end of the + // string + for c in input_without_scheme.chars() { + match c { + '/' => return false, // found the start of the path, and it's not empty + '?' | '#' => return true, // found the start of the query or fragment, so path is empty + _ => (), + } + } + + // reached the end of the string without finding a path, so it's empty + true } /// check host_required and substitute `default_host`, `default_port` & `default_path` if they aren't set diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index 6ebaa90f6..1b1107f43 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -277,6 +277,79 @@ def test_url_cases(url_validator, url, expected, mode): url_test_case_helper(url, expected, mode, url_validator) +@pytest.mark.parametrize( + ('url', 'expected', 'expected_path'), + [ + ('http://example.com', 'http://example.com', None), + ('http:example.com', 'http://example.com', None), + ('http:/example.com', 'http://example.com', None), + ('http://example.com/', 'http://example.com/', '/'), + ('http:example.com/', 'http://example.com/', '/'), + ('http:/example.com/', 'http://example.com/', '/'), + ('http://example.com?x=1', 'http://example.com?x=1', None), + ('http://example.com/?x=1', 'http://example.com/?x=1', '/'), + ('http://example.com#foo', 'http://example.com#foo', None), + ('http://example.com/#foo', 'http://example.com/#foo', '/'), + ('http://example.com/path', 'http://example.com/path', '/path'), + ('http://example.com/path/', 'http://example.com/path/', '/path/'), + ('http://example.com/path?x=1', 'http://example.com/path?x=1', '/path'), + ('http://example.com/path/?x=1', 'http://example.com/path/?x=1', '/path/'), + ], +) +def test_trailing_slash(url: str, expected: str, expected_path: Optional[str]): + url1 = Url(url, preserve_empty_path=True) + assert str(url1) == expected + assert url1.unicode_string() == expected + assert url1.path == expected_path + + v = SchemaValidator(core_schema.url_schema(preserve_empty_path=True)) + url2 = v.validate_python(url) + assert str(url2) == expected + assert url2.unicode_string() == expected + assert url2.path == expected_path + + v = SchemaValidator(core_schema.url_schema(), CoreConfig(url_preserve_empty_path=True)) + url3 = v.validate_python(url) + assert str(url3) == expected + assert url3.unicode_string() == expected + assert url3.path == expected_path + + +@pytest.mark.parametrize( + ('url', 'expected', 'expected_path'), + [ + ('http://example.com', 'http://example.com', None), + ('http://example.com/', 'http://example.com/', '/'), + ('http://example.com/path', 'http://example.com/path', '/path'), + ('http://example.com/path/', 'http://example.com/path/', '/path/'), + ('http://example.com,example.org', 'http://example.com,example.org', None), + ('http://example.com,example.org/', 'http://example.com,example.org/', '/'), + ('http://localhost,127.0.0.1', 'http://localhost,127.0.0.1', None), + ('http://localhost,127.0.0.1/', 'http://localhost,127.0.0.1/', '/'), + ('http:localhost,127.0.0.1', 'http://localhost,127.0.0.1', None), + ('http://localhost,127.0.0.1/path', 'http://localhost,127.0.0.1/path', '/path'), + ('http://localhost,127.0.0.1/path/', 'http://localhost,127.0.0.1/path/', '/path/'), + ], +) +def test_multi_trailing_slash(url: str, expected: str, expected_path: Optional[str]): + url1 = MultiHostUrl(url, preserve_empty_path=True) + assert str(url1) == expected + assert url1.unicode_string() == expected + assert url1.path == expected_path + + v = SchemaValidator(core_schema.multi_host_url_schema(preserve_empty_path=True)) + url2 = v.validate_python(url) + assert str(url2) == expected + assert url2.unicode_string() == expected + assert url2.path == expected_path + + v = SchemaValidator(core_schema.multi_host_url_schema(), CoreConfig(url_preserve_empty_path=True)) + url3 = v.validate_python(url) + assert str(url3) == expected + assert url3.unicode_string() == expected + assert url3.path == expected_path + + @pytest.mark.parametrize( 'validator_kwargs,url,expected', [