From a6302c818feab7e096eeda564a25b33fd72cc992 Mon Sep 17 00:00:00 2001 From: hrmny <8845940+ForsakenHarmony@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:00:01 +0200 Subject: [PATCH] feat: add support for preserving characters when decoding --- percent_encoding/src/lib.rs | 129 ++++++++++++++++++++++++++++++++---- 1 file changed, 116 insertions(+), 13 deletions(-) diff --git a/percent_encoding/src/lib.rs b/percent_encoding/src/lib.rs index 2213943b..693bbc26 100644 --- a/percent_encoding/src/lib.rs +++ b/percent_encoding/src/lib.rs @@ -66,7 +66,7 @@ use core::{fmt, mem, ops, slice, str}; /// /// https://url.spec.whatwg.org/#fragment-percent-encode-set /// const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); /// ``` -#[derive(Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct AsciiSet { mask: [Chunk; ASCII_RANGE_LEN / BITS_PER_CHUNK], } @@ -79,7 +79,7 @@ const BITS_PER_CHUNK: usize = 8 * mem::size_of::<Chunk>(); impl AsciiSet { /// An empty set. - pub const EMPTY: AsciiSet = AsciiSet { + pub const EMPTY: &'static AsciiSet = &AsciiSet { mask: [0; ASCII_RANGE_LEN / BITS_PER_CHUNK], }; @@ -101,6 +101,18 @@ impl AsciiSet { AsciiSet { mask } } + pub const fn add_range(&self, start: u8, end: u8) -> Self { + let mut new = AsciiSet { mask: self.mask }; + + let mut i = start; + while i <= end { + new = new.add(i); + i += 1; + } + + new + } + pub const fn remove(&self, byte: u8) -> Self { let mut mask = self.mask; mask[byte as usize / BITS_PER_CHUNK] &= !(1 << (byte as usize % BITS_PER_CHUNK)); @@ -108,7 +120,7 @@ impl AsciiSet { } /// Return the union of two sets. 
- pub const fn union(&self, other: Self) -> Self { + pub const fn union(&self, other: &Self) -> Self { let mask = [ self.mask[0] | other.mask[0], self.mask[1] | other.mask[1], @@ -128,7 +140,15 @@ impl AsciiSet { impl ops::Add for AsciiSet { type Output = Self; - fn add(self, other: Self) -> Self { + fn add(self, other: Self) -> Self::Output { + self.union(&other) + } +} + +impl ops::Add for &AsciiSet { + type Output = AsciiSet; + + fn add(self, other: Self) -> Self::Output { self.union(other) } } @@ -136,7 +156,15 @@ impl ops::Add for AsciiSet { impl ops::Not for AsciiSet { type Output = Self; - fn not(self) -> Self { + fn not(self) -> Self::Output { + self.complement() + } +} + +impl ops::Not for &AsciiSet { + type Output = AsciiSet; + + fn not(self) -> Self::Output { self.complement() } } @@ -268,7 +296,7 @@ pub fn percent_encode_byte(byte: u8) -> &'static str { /// assert_eq!(percent_encode(b"foo bar?", NON_ALPHANUMERIC).to_string(), "foo%20bar%3F"); /// ``` #[inline] -pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'static AsciiSet) -> PercentEncode<'a> { +pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'a AsciiSet) -> PercentEncode<'a> { PercentEncode { bytes: input, ascii_set, @@ -287,7 +315,7 @@ pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'static AsciiSet) -> Perc /// assert_eq!(utf8_percent_encode("foo bar?", NON_ALPHANUMERIC).to_string(), "foo%20bar%3F"); /// ``` #[inline] -pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'static AsciiSet) -> PercentEncode<'a> { +pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'a AsciiSet) -> PercentEncode<'a> { percent_encode(input.as_bytes(), ascii_set) } @@ -295,7 +323,7 @@ pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'static AsciiSet) -> #[derive(Clone)] pub struct PercentEncode<'a> { bytes: &'a [u8], - ascii_set: &'static AsciiSet, + ascii_set: &'a AsciiSet, } impl<'a> Iterator for PercentEncode<'a> { @@ -372,6 +400,19 @@ pub fn percent_decode_str(input: 
&str) -> PercentDecode<'_> { percent_decode(input.as_bytes()) } +/// Percent-decode the given string preserving the given ascii_set. +/// +/// +/// +/// See [`percent_decode`] regarding the return type. +#[inline] +pub fn percent_decode_str_with_set<'a>( + input: &'a str, + ascii_set: &'a AsciiSet, +) -> PercentDecode<'a> { + percent_decode_with_set(input.as_bytes(), ascii_set) +} + /// Percent-decode the given bytes. /// /// @@ -394,6 +435,36 @@ pub fn percent_decode_str(input: &str) -> PercentDecode<'_> { pub fn percent_decode(input: &[u8]) -> PercentDecode<'_> { PercentDecode { bytes: input.iter(), + ascii_set: None, + } +} + +/// Percent-decode the given bytes preserving the given ascii_set. +/// +/// +/// +/// Any sequence of `%` followed by two hexadecimal digits is decoded, except when the decoded byte is in the given [`AsciiSet`]. +/// The return type: +/// +/// * Implements `Into<Cow<[u8]>>` borrowing `input` when it contains no percent-encoded sequence, +/// * Implements `Iterator<Item = u8>` and therefore has a `.collect::<Vec<u8>>()` method, +/// * Has `decode_utf8()` and `decode_utf8_lossy()` methods. 
+/// +/// # Examples +/// +/// ``` +/// use percent_encoding::{percent_decode_with_set, NON_ALPHANUMERIC}; +/// +/// assert_eq!(percent_decode_with_set(b"%66oo%20bar%3f", &!NON_ALPHANUMERIC).decode_utf8().unwrap(), "%66oo bar?"); +/// ``` +#[inline] +pub fn percent_decode_with_set<'a>( + input: &'a [u8], + ascii_set: &'a AsciiSet, +) -> PercentDecode<'a> { + PercentDecode { + bytes: input.iter(), + ascii_set: Some(ascii_set), } } @@ -401,6 +472,7 @@ pub fn percent_decode(input: &[u8]) -> PercentDecode<'_> { #[derive(Clone, Debug)] pub struct PercentDecode<'a> { bytes: slice::Iter<'a, u8>, + ascii_set: Option<&'a AsciiSet>, } fn after_percent_sign(iter: &mut slice::Iter<'_, u8>) -> Option<u8> { @@ -411,13 +483,35 @@ fn after_percent_sign(iter: &mut slice::Iter<'_, u8>) -> Option<u8> { Some(h as u8 * 0x10 + l as u8) } +fn after_percent_sign_lookahead<'a>( + iter: &mut slice::Iter<'a, u8>, +) -> Option<(u8, slice::Iter<'a, u8>)> { + let mut cloned_iter = iter.clone(); + let h = char::from(*cloned_iter.next()?).to_digit(16)?; + let l = char::from(*cloned_iter.next()?).to_digit(16)?; + Some((h as u8 * 0x10 + l as u8, cloned_iter)) +} + impl<'a> Iterator for PercentDecode<'a> { type Item = u8; fn next(&mut self) -> Option<u8> { self.bytes.next().map(|&byte| { - if byte == b'%' { - after_percent_sign(&mut self.bytes).unwrap_or(byte) + if byte != b'%' { + return byte; + } + + let Some((decoded_byte, iter)) = after_percent_sign_lookahead(&mut self.bytes) else { + return byte; + }; + + let should_decode = self + .ascii_set + .map_or(true, |ascii_set| !ascii_set.contains(decoded_byte)); + + if should_decode { + self.bytes = iter; + decoded_byte } else { byte } @@ -447,11 +541,20 @@ impl<'a> PercentDecode<'a> { let mut bytes_iter = self.bytes.clone(); while bytes_iter.any(|&b| b == b'%') { if let Some(decoded_byte) = after_percent_sign(&mut bytes_iter) { + if let Some(ascii_set) = self.ascii_set { + if ascii_set.contains(decoded_byte) { + continue; + } + } + let initial_bytes = 
self.bytes.as_slice(); let unchanged_bytes_len = initial_bytes.len() - bytes_iter.len() - 3; let mut decoded = initial_bytes[..unchanged_bytes_len].to_owned(); decoded.push(decoded_byte); - decoded.extend(PercentDecode { bytes: bytes_iter }); + decoded.extend(PercentDecode { + bytes: bytes_iter, + ascii_set: self.ascii_set, + }); return Some(decoded); } } @@ -542,8 +645,8 @@ mod tests { /// useful for defining sets in a modular way. #[test] fn union() { - const A: AsciiSet = AsciiSet::EMPTY.add(b'A'); - const B: AsciiSet = AsciiSet::EMPTY.add(b'B'); + const A: &AsciiSet = &AsciiSet::EMPTY.add(b'A'); + const B: &AsciiSet = &AsciiSet::EMPTY.add(b'B'); const UNION: AsciiSet = A.union(B); const EXPECTED: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B'); assert_eq!(UNION, EXPECTED);