From 710e1e7da0aa68e2f1596fe48768a88acff07736 Mon Sep 17 00:00:00 2001 From: Josh McKinney Date: Thu, 26 Sep 2024 17:13:20 -0700 Subject: [PATCH] refactor and add tests for percent encoding (#977) * refactor: move AsciiSet related code to ascii_set.rs Makes it a little easier to just look at the code related to the set rather than encoding * Add tests for percent_encoding --- percent_encoding/src/ascii_set.rs | 213 ++++++++++++++++++++++ percent_encoding/src/lib.rs | 287 +++++++++++------------------- 2 files changed, 316 insertions(+), 184 deletions(-) create mode 100644 percent_encoding/src/ascii_set.rs diff --git a/percent_encoding/src/ascii_set.rs b/percent_encoding/src/ascii_set.rs new file mode 100644 index 00000000..41cd235e --- /dev/null +++ b/percent_encoding/src/ascii_set.rs @@ -0,0 +1,213 @@ +// Copyright 2013-2016 The rust-url developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::{mem, ops}; + +/// Represents a set of characters or bytes in the ASCII range. +/// +/// This is used in [`percent_encode`] and [`utf8_percent_encode`]. +/// This is similar to [percent-encode sets](https://url.spec.whatwg.org/#percent-encoded-bytes). +/// +/// Use the `add` method of an existing set to define a new set. For example: +/// +/// [`percent_encode`]: crate::percent_encode +/// [`utf8_percent_encode`]: crate::utf8_percent_encode +/// +/// ``` +/// use percent_encoding::{AsciiSet, CONTROLS}; +/// +/// /// https://url.spec.whatwg.org/#fragment-percent-encode-set +/// const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); +/// ``` +#[derive(Debug, PartialEq, Eq)] +pub struct AsciiSet { + mask: [Chunk; ASCII_RANGE_LEN / BITS_PER_CHUNK], +} + +type Chunk = u32; + +const ASCII_RANGE_LEN: usize = 0x80; + +const BITS_PER_CHUNK: usize = 8 * mem::size_of::(); + +impl AsciiSet { + /// An empty set. + pub const EMPTY: AsciiSet = AsciiSet { + mask: [0; ASCII_RANGE_LEN / BITS_PER_CHUNK], + }; + + /// Called with UTF-8 bytes rather than code points. + /// Not used for non-ASCII bytes. + pub(crate) const fn contains(&self, byte: u8) -> bool { + let chunk = self.mask[byte as usize / BITS_PER_CHUNK]; + let mask = 1 << (byte as usize % BITS_PER_CHUNK); + (chunk & mask) != 0 + } + + pub(crate) fn should_percent_encode(&self, byte: u8) -> bool { + !byte.is_ascii() || self.contains(byte) + } + + pub const fn add(&self, byte: u8) -> Self { + let mut mask = self.mask; + mask[byte as usize / BITS_PER_CHUNK] |= 1 << (byte as usize % BITS_PER_CHUNK); + AsciiSet { mask } + } + + pub const fn remove(&self, byte: u8) -> Self { + let mut mask = self.mask; + mask[byte as usize / BITS_PER_CHUNK] &= !(1 << (byte as usize % BITS_PER_CHUNK)); + AsciiSet { mask } + } + + /// Return the union of two sets. + pub const fn union(&self, other: Self) -> Self { + let mask = [ + self.mask[0] | other.mask[0], + self.mask[1] | other.mask[1], + self.mask[2] | other.mask[2], + self.mask[3] | other.mask[3], + ]; + AsciiSet { mask } + } + + /// Return the negation of the set. + pub const fn complement(&self) -> Self { + let mask = [!self.mask[0], !self.mask[1], !self.mask[2], !self.mask[3]]; + AsciiSet { mask } + } +} + +impl ops::Add for AsciiSet { + type Output = Self; + + fn add(self, other: Self) -> Self { + self.union(other) + } +} + +impl ops::Not for AsciiSet { + type Output = Self; + + fn not(self) -> Self { + self.complement() + } +} + +/// The set of 0x00 to 0x1F (C0 controls), and 0x7F (DEL). +/// +/// Note that this includes the newline and tab characters, but not the space 0x20. +/// +/// +pub const CONTROLS: &AsciiSet = &AsciiSet { + mask: [ + !0_u32, // C0: 0x00 to 0x1F (32 bits set) + 0, + 0, + 1 << (0x7F_u32 % 32), // DEL: 0x7F (one bit set) + ], +}; + +macro_rules! static_assert { + ($( $bool: expr, )+) => { + fn _static_assert() { + $( + let _ = mem::transmute::<[u8; $bool as usize], u8>; + )+ + } + } +} + +static_assert! { + CONTROLS.contains(0x00), + CONTROLS.contains(0x1F), + !CONTROLS.contains(0x20), + !CONTROLS.contains(0x7E), + CONTROLS.contains(0x7F), +} + +/// Everything that is not an ASCII letter or digit. +/// +/// This is probably more eager than necessary in any context. +pub const NON_ALPHANUMERIC: &AsciiSet = &CONTROLS + .add(b' ') + .add(b'!') + .add(b'"') + .add(b'#') + .add(b'$') + .add(b'%') + .add(b'&') + .add(b'\'') + .add(b'(') + .add(b')') + .add(b'*') + .add(b'+') + .add(b',') + .add(b'-') + .add(b'.') + .add(b'/') + .add(b':') + .add(b';') + .add(b'<') + .add(b'=') + .add(b'>') + .add(b'?') + .add(b'@') + .add(b'[') + .add(b'\\') + .add(b']') + .add(b'^') + .add(b'_') + .add(b'`') + .add(b'{') + .add(b'|') + .add(b'}') + .add(b'~'); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_op() { + let left = AsciiSet::EMPTY.add(b'A'); + let right = AsciiSet::EMPTY.add(b'B'); + let expected = AsciiSet::EMPTY.add(b'A').add(b'B'); + assert_eq!(left + right, expected); + } + + #[test] + fn not_op() { + let set = AsciiSet::EMPTY.add(b'A').add(b'B'); + let not_set = !set; + assert!(!not_set.contains(b'A')); + assert!(not_set.contains(b'C')); + } + + /// This test ensures that we can get the union of two sets as a constant value, which is + /// useful for defining sets in a modular way. + #[test] + fn union() { + const A: AsciiSet = AsciiSet::EMPTY.add(b'A'); + const B: AsciiSet = AsciiSet::EMPTY.add(b'B'); + const UNION: AsciiSet = A.union(B); + const EXPECTED: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B'); + assert_eq!(UNION, EXPECTED); + } + + /// This test ensures that we can get the complement of a set as a constant value, which is + /// useful for defining sets in a modular way. + #[test] + fn complement() { + const BOTH: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B'); + const COMPLEMENT: AsciiSet = BOTH.complement(); + assert!(!COMPLEMENT.contains(b'A')); + assert!(!COMPLEMENT.contains(b'B')); + assert!(COMPLEMENT.contains(b'C')); + } +} diff --git a/percent_encoding/src/lib.rs b/percent_encoding/src/lib.rs index 2213943b..fba0cd91 100644 --- a/percent_encoding/src/lib.rs +++ b/percent_encoding/src/lib.rs @@ -51,165 +51,11 @@ use alloc::{ string::String, vec::Vec, }; -use core::{fmt, mem, ops, slice, str}; +use core::{fmt, slice, str}; -/// Represents a set of characters or bytes in the ASCII range. -/// -/// This is used in [`percent_encode`] and [`utf8_percent_encode`]. -/// This is similar to [percent-encode sets](https://url.spec.whatwg.org/#percent-encoded-bytes). -/// -/// Use the `add` method of an existing set to define a new set. For example: -/// -/// ``` -/// use percent_encoding::{AsciiSet, CONTROLS}; -/// -/// /// https://url.spec.whatwg.org/#fragment-percent-encode-set -/// const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); -/// ``` -#[derive(Debug, PartialEq, Eq)] -pub struct AsciiSet { - mask: [Chunk; ASCII_RANGE_LEN / BITS_PER_CHUNK], -} - -type Chunk = u32; - -const ASCII_RANGE_LEN: usize = 0x80; - -const BITS_PER_CHUNK: usize = 8 * mem::size_of::(); +pub use self::ascii_set::{AsciiSet, CONTROLS, NON_ALPHANUMERIC}; -impl AsciiSet { - /// An empty set. - pub const EMPTY: AsciiSet = AsciiSet { - mask: [0; ASCII_RANGE_LEN / BITS_PER_CHUNK], - }; - - /// Called with UTF-8 bytes rather than code points. - /// Not used for non-ASCII bytes. - const fn contains(&self, byte: u8) -> bool { - let chunk = self.mask[byte as usize / BITS_PER_CHUNK]; - let mask = 1 << (byte as usize % BITS_PER_CHUNK); - (chunk & mask) != 0 - } - - fn should_percent_encode(&self, byte: u8) -> bool { - !byte.is_ascii() || self.contains(byte) - } - - pub const fn add(&self, byte: u8) -> Self { - let mut mask = self.mask; - mask[byte as usize / BITS_PER_CHUNK] |= 1 << (byte as usize % BITS_PER_CHUNK); - AsciiSet { mask } - } - - pub const fn remove(&self, byte: u8) -> Self { - let mut mask = self.mask; - mask[byte as usize / BITS_PER_CHUNK] &= !(1 << (byte as usize % BITS_PER_CHUNK)); - AsciiSet { mask } - } - - /// Return the union of two sets. - pub const fn union(&self, other: Self) -> Self { - let mask = [ - self.mask[0] | other.mask[0], - self.mask[1] | other.mask[1], - self.mask[2] | other.mask[2], - self.mask[3] | other.mask[3], - ]; - AsciiSet { mask } - } - - /// Return the negation of the set. - pub const fn complement(&self) -> Self { - let mask = [!self.mask[0], !self.mask[1], !self.mask[2], !self.mask[3]]; - AsciiSet { mask } - } -} - -impl ops::Add for AsciiSet { - type Output = Self; - - fn add(self, other: Self) -> Self { - self.union(other) - } -} - -impl ops::Not for AsciiSet { - type Output = Self; - - fn not(self) -> Self { - self.complement() - } -} - -/// The set of 0x00 to 0x1F (C0 controls), and 0x7F (DEL). -/// -/// Note that this includes the newline and tab characters, but not the space 0x20. -/// -/// -pub const CONTROLS: &AsciiSet = &AsciiSet { - mask: [ - !0_u32, // C0: 0x00 to 0x1F (32 bits set) - 0, - 0, - 1 << (0x7F_u32 % 32), // DEL: 0x7F (one bit set) - ], -}; - -macro_rules! static_assert { - ($( $bool: expr, )+) => { - fn _static_assert() { - $( - let _ = mem::transmute::<[u8; $bool as usize], u8>; - )+ - } - } -} - -static_assert! { - CONTROLS.contains(0x00), - CONTROLS.contains(0x1F), - !CONTROLS.contains(0x20), - !CONTROLS.contains(0x7E), - CONTROLS.contains(0x7F), -} - -/// Everything that is not an ASCII letter or digit. -/// -/// This is probably more eager than necessary in any context. -pub const NON_ALPHANUMERIC: &AsciiSet = &CONTROLS - .add(b' ') - .add(b'!') - .add(b'"') - .add(b'#') - .add(b'$') - .add(b'%') - .add(b'&') - .add(b'\'') - .add(b'(') - .add(b')') - .add(b'*') - .add(b'+') - .add(b',') - .add(b'-') - .add(b'.') - .add(b'/') - .add(b':') - .add(b';') - .add(b'<') - .add(b'=') - .add(b'>') - .add(b'?') - .add(b'@') - .add(b'[') - .add(b'\\') - .add(b']') - .add(b'^') - .add(b'_') - .add(b'`') - .add(b'{') - .add(b'|') - .add(b'}') - .add(b'~'); +mod ascii_set; /// Return the percent-encoding of the given byte. /// @@ -292,7 +138,7 @@ pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'static AsciiSet) -> } /// The return type of [`percent_encode`] and [`utf8_percent_encode`]. -#[derive(Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PercentEncode<'a> { bytes: &'a [u8], ascii_set: &'static AsciiSet, @@ -520,43 +366,116 @@ fn decode_utf8_lossy(input: Cow<'_, [u8]>) -> Cow<'_, str> { #[cfg(test)] mod tests { + use super::*; #[test] - fn add_op() { - let left = AsciiSet::EMPTY.add(b'A'); - let right = AsciiSet::EMPTY.add(b'B'); - let expected = AsciiSet::EMPTY.add(b'A').add(b'B'); - assert_eq!(left + right, expected); + fn percent_encode_byte() { + for i in 0..=0xFF { + let encoded = super::percent_encode_byte(i); + assert_eq!(encoded, alloc::format!("%{:02X}", i)); + } + } + + #[test] + fn percent_encode_accepts_ascii_set_ref() { + let encoded = percent_encode(b"foo bar?", &AsciiSet::EMPTY); + assert_eq!(encoded.collect::(), "foo bar?"); + } + + #[test] + fn percent_encode_collect() { + let encoded = percent_encode(b"foo bar?", NON_ALPHANUMERIC); + assert_eq!(encoded.collect::(), String::from("foo%20bar%3F")); + + let encoded = percent_encode(b"\x00\x01\x02\x03", CONTROLS); + assert_eq!(encoded.collect::(), String::from("%00%01%02%03")); + } + + #[test] + fn percent_encode_display() { + let encoded = percent_encode(b"foo bar?", NON_ALPHANUMERIC); + assert_eq!(alloc::format!("{}", encoded), "foo%20bar%3F"); + } + + #[test] + fn percent_encode_cow() { + let encoded = percent_encode(b"foo bar?", NON_ALPHANUMERIC); + assert_eq!(Cow::from(encoded), "foo%20bar%3F"); + } + + #[test] + fn utf8_percent_encode_accepts_ascii_set_ref() { + let encoded = super::utf8_percent_encode("foo bar?", &AsciiSet::EMPTY); + assert_eq!(encoded.collect::(), "foo bar?"); + } + + #[test] + fn utf8_percent_encode() { + assert_eq!( + super::utf8_percent_encode("foo bar?", NON_ALPHANUMERIC), + percent_encode(b"foo bar?", NON_ALPHANUMERIC) + ); + } + + #[test] + fn percent_decode() { + assert_eq!( + super::percent_decode(b"foo%20bar%3f") + .decode_utf8() + .unwrap(), + "foo bar?" + ); + } + + #[test] + fn percent_decode_str() { + assert_eq!( + super::percent_decode_str("foo%20bar%3f") + .decode_utf8() + .unwrap(), + "foo bar?" + ); + } + + #[test] + fn percent_decode_collect() { + let decoded = super::percent_decode(b"foo%20bar%3f"); + assert_eq!(decoded.collect::>(), b"foo bar?"); + } + + #[test] + fn percent_decode_cow() { + let decoded = super::percent_decode(b"foo%20bar%3f"); + assert_eq!(Cow::from(decoded), Cow::Owned::<[u8]>(b"foo bar?".to_vec())); + + let decoded = super::percent_decode(b"foo bar?"); + assert_eq!(Cow::from(decoded), Cow::Borrowed(b"foo bar?")); } #[test] - fn not_op() { - let set = AsciiSet::EMPTY.add(b'A').add(b'B'); - let not_set = !set; - assert!(!not_set.contains(b'A')); - assert!(not_set.contains(b'C')); + fn percent_decode_invalid_utf8() { + // Invalid UTF-8 sequence + let decoded = super::percent_decode(b"%00%9F%92%96") + .decode_utf8() + .unwrap_err(); + assert_eq!(decoded.valid_up_to(), 1); + assert_eq!(decoded.error_len(), Some(1)); } - /// This test ensures that we can get the union of two sets as a constant value, which is - /// useful for defining sets in a modular way. #[test] - fn union() { - const A: AsciiSet = AsciiSet::EMPTY.add(b'A'); - const B: AsciiSet = AsciiSet::EMPTY.add(b'B'); - const UNION: AsciiSet = A.union(B); - const EXPECTED: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B'); - assert_eq!(UNION, EXPECTED); + fn percent_decode_utf8_lossy() { + assert_eq!( + super::percent_decode(b"%F0%9F%92%96").decode_utf8_lossy(), + "💖" + ); } - /// This test ensures that we can get the complement of a set as a constant value, which is - /// useful for defining sets in a modular way. #[test] - fn complement() { - const BOTH: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B'); - const COMPLEMENT: AsciiSet = BOTH.complement(); - assert!(!COMPLEMENT.contains(b'A')); - assert!(!COMPLEMENT.contains(b'B')); - assert!(COMPLEMENT.contains(b'C')); + fn percent_decode_utf8_lossy_invalid_utf8() { + assert_eq!( + super::percent_decode(b"%00%9F%92%96").decode_utf8_lossy(), + "\u{0}���" + ); } }