From c35a832bac0c3178f3b8cc4416f4d446ae570eb7 Mon Sep 17 00:00:00 2001 From: grouville Date: Wed, 28 Feb 2024 01:58:09 -0800 Subject: [PATCH 01/13] feat: port randomize and sz_generate to Rust First raw function API design PS: still draft, having an issue on linker. I believe this is due to the fact that sz_generate is implemented in the header ? The FFI doesnt seem to work Nonetheless, this is a DX I envision. Signed-off-by: grouville --- Cargo.lock | 2 +- rust/lib.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 080c826a..e23a0a73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,7 +19,7 @@ checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "stringzilla" -version = "3.2.0" +version = "3.3.0" dependencies = [ "cc", ] diff --git a/rust/lib.rs b/rust/lib.rs index fff99606..fc866ab6 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -2,6 +2,8 @@ use core::ffi::c_void; +type SzRandomGeneratorT = extern "C" fn(*mut c_void) -> u64; + // Import the functions from the StringZilla C library. extern "C" { fn sz_find( @@ -64,6 +66,15 @@ extern "C" { gap: i8, allocator: *const c_void, ) -> isize; + + fn sz_generate( + alphabet: *const c_void, + alphabet_size: usize, + text: *mut c_void, + length: usize, + generate: SzRandomGeneratorT, + generator: *mut c_void, + ); } /// The [StringZilla] trait provides a collection of string searching and manipulation functionalities. @@ -99,6 +110,39 @@ where fn sz_alignment_score(&self, needle: N, matrix: [[i8; 256]; 256], gap: i8) -> isize; } +trait MutableStringZilla +where + N: AsRef<[u8]>, +{ + /// Generates a random string for a given alphabet. + /// Replaces the buffer with a random string of the same length. 
+ // Cannot be String, as it is not AsMut<[u8]> + fn randomize(&mut self, alphabet: N, generate: SzRandomGeneratorT); +} + +impl MutableStringZilla for T +where + T: AsMut<[u8]>, + N: AsRef<[u8]>, +{ + fn randomize(&mut self, alphabet: N, generate: SzRandomGeneratorT) { + let text = self.as_mut(); + let text_len = text.len(); + let alphabet_slice = alphabet.as_ref(); // Convert N to &[u8]; + + unsafe { + sz_generate( + alphabet_slice.as_ptr() as *const c_void, + alphabet_slice.len(), + text.as_mut_ptr() as *mut c_void, + text_len, + generate, // Directly use the function pointer + core::ptr::null_mut(), // No need for a generator context + ); + } + } +} + impl StringZilla for T where T: AsRef<[u8]>, @@ -284,9 +328,10 @@ where #[cfg(test)] mod tests { + use core::ffi::c_void; use std::borrow::Cow; - use crate::StringZilla; + use crate::{MutableStringZilla, StringZilla, SzRandomGeneratorT}; fn unary_substitution_costs() -> [[i8; 256]; 256] { let mut result = [[0; 256]; 256]; @@ -300,6 +345,11 @@ mod tests { result } + // Define a simple deterministic generator for testing purposes. 
+ extern "C" fn test_generator(_: *mut c_void) -> u64 { + 4 // Always returns 4 for predictability in tests + } + #[test] fn levenshtein() { assert_eq!("hello".sz_edit_distance("hell"), 1); @@ -359,4 +409,34 @@ mod tests { Some(12) ); } + + #[test] + fn test_randomize_with_byte_slice() { + let mut my_bytes: Vec = vec![0; 10]; // A buffer of ten zeros + let alphabet: &[u8] = b"abcd"; // A byte slice alphabet + my_bytes.randomize(alphabet, test_generator); + + // Assert that all bytes in `my_bytes` are now 'd' (ASCII 100), based on the test_generator + assert!(my_bytes.iter().all(|&b| b == b'd')); + } + + #[test] + fn test_randomize_with_vec() { + let mut my_bytes: Vec = vec![0; 10]; + let alphabet = vec![b'a', b'b', b'c', b'd']; // A Vec alphabet + my_bytes.randomize(&alphabet, test_generator); + + // Assert similar to the previous test + assert!(my_bytes.iter().all(|&b| b == b'd')); + } + + #[test] + fn test_randomize_with_string() { + let mut my_bytes: Vec = vec![0; 10]; + let alphabet = "abcd".to_string(); // A String alphabet + my_bytes.randomize(&alphabet, test_generator); + + // Assert similar to the previous test + assert!(my_bytes.iter().all(|&b| b == b'd')); + } } From f8d59d93bc547a679023b611bf6cf7a1a8749f0c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 28 Feb 2024 19:33:14 +0000 Subject: [PATCH 02/13] Fix: `Str() in Str()` checks in Python --- python/lib.c | 29 ++++++++++++++++++++++++++++- scripts/test.py | 2 ++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/python/lib.c b/python/lib.c index 5acbb752..40a0145f 100644 --- a/python/lib.c +++ b/python/lib.c @@ -801,7 +801,34 @@ static PyObject *Strs_subscript(Strs *self, PyObject *key) { } // Will be called by the `PySequence_Contains` -static int Strs_contains(Str *self, PyObject *arg) { return 0; } +static int Strs_contains(Str *self, PyObject *needle_obj) { + + // Validate and convert `needle` + sz_string_view_t needle; + if 
(!export_string_like(needle_obj, &needle.start, &needle.length)) { + PyErr_SetString(PyExc_TypeError, "The needle argument must be string-like"); + return NULL; + } + + // Depending on the layout, we will need to use different logic + Py_ssize_t count = Strs_len(self); + get_string_at_offset_t getter = str_at_offset_getter(self); + if (!getter) { + PyErr_SetString(PyExc_TypeError, "Unknown Strs kind"); + return NULL; + } + + // Time for a full-scan + for (Py_ssize_t i = 0; i < count; ++i) { + PyObject *parent = NULL; + char const *start = NULL; + size_t length = 0; + getter(self, i, count, &parent, &start, &length); + if (length == needle.length && sz_equal(start, needle.start, needle.length) == sz_true_k) return 1; + } + + return 0; +} static PyObject *Str_richcompare(PyObject *self, PyObject *other, int op) { diff --git a/scripts/test.py b/scripts/test.py index ee9225f0..e7b3ca6e 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -88,6 +88,8 @@ def test_unit_sequence(): lines = big.splitlines() assert [2, 1, 0] == list(lines.order()) + assert "p3" in lines + assert "p4" not in lines lines.sort() assert [0, 1, 2] == list(lines.order()) From 65564b9347a0f9874638d05524106946114d38db Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 28 Feb 2024 21:20:28 +0000 Subject: [PATCH 03/13] Improve: Faster rich comparisons --- include/stringzilla/stringzilla.h | 32 ++++-- python/lib.c | 173 ++++++++++++++++++++++++++++-- scripts/test.py | 62 ++++++++++- 3 files changed, 248 insertions(+), 19 deletions(-) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 5811b8ef..e39300a2 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -1722,21 +1722,34 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz #pragma GCC diagnostic pop } +/** + * One option to avoid branching is to use conditional moves and lookup the comparison 
result in a table: + * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; + * for (; a != min_end; ++a, ++b) + * if (*a != *b) return ordering_lookup[*a < *b]; + * That, however, introduces a data-dependency. + * A cleaner option is to perform two comparisons and a subtraction. + * One instruction more, but no data-dependency. + */ +#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) + SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); sz_size_t min_length = a_shorter ? a_length : b_length; sz_cptr_t min_end = a + min_length; #if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec.u64 = sz_u64_bytes_reverse(sz_u64_load(a).u64); - b_vec.u64 = sz_u64_bytes_reverse(sz_u64_load(b).u64); - if (a_vec.u64 != b_vec.u64) return ordering_lookup[a_vec.u64 < b_vec.u64]; + a_vec = sz_u64_load(a); + b_vec = sz_u64_load(b); + if (a_vec.u64 != b_vec.u64) + return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); } #endif for (; a != min_end; ++a, ++b) - if (*a != *b) return ordering_lookup[*a < *b]; - return a_length != b_length ? 
ordering_lookup[a_shorter] : sz_equal_k; + if (*a != *b) return _sz_order_scalars(*a, *b); + + // If the strings are equal up to `min_end`, then the shorter string is smaller + return _sz_order_scalars(a_length, b_length); } /** @@ -3890,7 +3903,6 @@ SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { } SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; sz_u512_vec_t a_vec, b_vec; __mmask64 a_mask, b_mask, mask_not_equal; @@ -3903,7 +3915,7 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr int first_diff = _tzcnt_u64(mask_not_equal); char a_char = a[first_diff]; char b_char = b[first_diff]; - return ordering_lookup[a_char < b_char]; + return _sz_order_scalars(a_char, b_char); } a += 64, b += 64, a_length -= 64, b_length -= 64; } @@ -3922,12 +3934,12 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr int first_diff = _tzcnt_u64(mask_not_equal); char a_char = a[first_diff]; char b_char = b[first_diff]; - return ordering_lookup[a_char < b_char]; + return _sz_order_scalars(a_char, b_char); } else // From logic perspective, the hardest cases are "abc\0" and "abc". // The result must be `sz_greater_k`, as the latter is shorter. - return a_length != b_length ? ordering_lookup[a_length < b_length] : sz_equal_k; + return _sz_order_scalars(a_length, b_length); } else return sz_equal_k; diff --git a/python/lib.c b/python/lib.c index 40a0145f..743aaef6 100644 --- a/python/lib.c +++ b/python/lib.c @@ -837,13 +837,7 @@ static PyObject *Str_richcompare(PyObject *self, PyObject *other, int op) { if (!export_string_like(self, &a_start, &a_length) || !export_string_like(other, &b_start, &b_length)) Py_RETURN_NOTIMPLEMENTED; - // Perform byte-wise comparison up to the minimum length - sz_size_t min_length = a_length < b_length ? 
a_length : b_length; - int order = memcmp(a_start, b_start, min_length); - - // If the strings are equal up to `min_length`, then the shorter string is smaller - if (order == 0) order = (a_length > b_length) - (a_length < b_length); - + int order = (int)sz_order(a_start, a_length, b_start, b_length); switch (op) { case Py_LT: return PyBool_FromLong(order < 0); case Py_LE: return PyBool_FromLong(order <= 0); @@ -855,6 +849,170 @@ static PyObject *Str_richcompare(PyObject *self, PyObject *other, int op) { } } +static PyObject *Strs_richcompare(PyObject *self, PyObject *other, int op) { + + Strs *a = (Strs *)self; + Py_ssize_t a_length = Strs_len(a); + get_string_at_offset_t a_getter = str_at_offset_getter(a); + if (!a_getter) { + PyErr_SetString(PyExc_TypeError, "Unknown Strs kind"); + return NULL; + } + + // If the other object is also a Strs, we can compare them much faster, + // avoiding the CPython API entirely + if (PyObject_TypeCheck(other, &StrsType)) { + Strs *b = (Strs *)other; + + // Check if lengths are equal + Py_ssize_t b_length = Strs_len(b); + if (a_length != b_length) { + if (op == Py_EQ) { Py_RETURN_FALSE; } + if (op == Py_NE) { Py_RETURN_TRUE; } + } + + // The second array may have a different layout + get_string_at_offset_t b_getter = str_at_offset_getter(b); + if (!b_getter) { + PyErr_SetString(PyExc_TypeError, "Unknown Strs kind"); + return NULL; + } + + // Check each item for equality + Py_ssize_t min_length = sz_min_of_two(a_length, b_length); + for (Py_ssize_t i = 0; i < min_length; i++) { + PyObject *ai_parent = NULL, *bi_parent = NULL; + char const *ai_start = NULL, *bi_start = NULL; + size_t ai_length = 0, bi_length = 0; + a_getter(a, i, a_length, &ai_parent, &ai_start, &ai_length); + b_getter(b, i, b_length, &bi_parent, &bi_start, &bi_length); + + // When dealing with arrays, early exists make sense only in some cases + int order = (int)sz_order(ai_start, ai_length, bi_start, bi_length); + switch (op) { + case Py_LT: + case Py_LE: + if 
(order > 0) { Py_RETURN_FALSE; } + break; + case Py_EQ: + if (order != 0) { Py_RETURN_FALSE; } + break; + case Py_NE: + if (order == 0) { Py_RETURN_TRUE; } + break; + case Py_GT: + case Py_GE: + if (order < 0) { Py_RETURN_FALSE; } + break; + default: break; + } + } + + // Prefixes are identical, compare lengths + switch (op) { + case Py_LT: return PyBool_FromLong(a_length < b_length); + case Py_LE: return PyBool_FromLong(a_length <= b_length); + case Py_EQ: return PyBool_FromLong(a_length == b_length); + case Py_NE: return PyBool_FromLong(a_length != b_length); + case Py_GT: return PyBool_FromLong(a_length > b_length); + case Py_GE: return PyBool_FromLong(a_length >= b_length); + default: Py_RETURN_NOTIMPLEMENTED; + } + } + + // The second argument is a sequence, but not a `Strs` object, + // so we need to iterate through it. + PyObject *other_iter = PyObject_GetIter(other); + if (!other_iter) { + PyErr_Clear(); + PyErr_SetString(PyExc_TypeError, "The second argument is not iterable"); + return NULL; + } + + // We may not even know the length of the second sequence, so + // let's just iterate as far as we can. 
+ Py_ssize_t i = 0; + PyObject *other_item; + for (; (other_item = PyIter_Next(other_iter)); ++i) { + // Check if the second array is longer than the first + if (a_length <= i) { + Py_DECREF(other_item); + Py_DECREF(other_iter); + switch (op) { + case Py_LT: Py_RETURN_TRUE; + case Py_LE: Py_RETURN_TRUE; + case Py_EQ: Py_RETURN_FALSE; + case Py_NE: Py_RETURN_TRUE; + case Py_GT: Py_RETURN_FALSE; + case Py_GE: Py_RETURN_FALSE; + default: Py_RETURN_NOTIMPLEMENTED; + } + } + + // Try unpacking the element from the second sequence + sz_string_view_t bi; + if (!export_string_like(other_item, &bi.start, &bi.length)) { + Py_DECREF(other_item); + Py_DECREF(other_iter); + PyErr_SetString(PyExc_TypeError, "The second container must contain string-like objects"); + return NULL; + } + + // Both sequences aren't exhausted yet + PyObject *ai_parent = NULL; + char const *ai_start = NULL; + size_t ai_length = 0; + a_getter(a, i, a_length, &ai_parent, &ai_start, &ai_length); + + // When dealing with arrays, early exists make sense only in some cases + int order = (int)sz_order(ai_start, ai_length, bi.start, bi.length); + switch (op) { + case Py_LT: + case Py_LE: + if (order > 0) { + Py_DECREF(other_item); + Py_DECREF(other_iter); + Py_RETURN_FALSE; + } + break; + case Py_EQ: + if (order != 0) { + Py_DECREF(other_item); + Py_DECREF(other_iter); + Py_RETURN_FALSE; + } + break; + case Py_NE: + if (order == 0) { + Py_DECREF(other_item); + Py_DECREF(other_iter); + Py_RETURN_TRUE; + } + break; + case Py_GT: + case Py_GE: + if (order < 0) { + Py_DECREF(other_item); + Py_DECREF(other_iter); + Py_RETURN_FALSE; + } + break; + default: break; + } + } + + // The prefixes are equal and the second sequence is exhausted, but the first one may not be + switch (op) { + case Py_LT: return PyBool_FromLong(i < a_length); + case Py_LE: Py_RETURN_TRUE; + case Py_EQ: return PyBool_FromLong(i == a_length); + case Py_NE: return PyBool_FromLong(i != a_length); + case Py_GT: Py_RETURN_FALSE; + case Py_GE: 
return PyBool_FromLong(i == a_length); + default: Py_RETURN_NOTIMPLEMENTED; + } +} + /** * @brief Saves a StringZilla string to disk. */ @@ -2154,6 +2312,7 @@ static PyTypeObject StrsType = { .tp_methods = Strs_methods, .tp_as_sequence = &Strs_as_sequence, .tp_as_mapping = &Strs_as_mapping, + .tp_richcompare = Strs_richcompare, }; #pragma endregion diff --git a/scripts/test.py b/scripts/test.py index e7b3ca6e..d9775fea 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -41,14 +41,72 @@ def test_unit_contains(): assert "xxx" not in big -def test_unit_rich_comparisons(): +def test_unit_str_rich_comparisons(): + # Equality assert Str("aa") == "aa" + assert Str("a") != "b" + assert Str("abc") == Str("abc") + assert Str("abc") != Str("abd") + + # Less than and less than or equal to assert Str("aa") < "b" + assert Str("ab") <= "ab" + assert Str("a") < Str("b") + assert Str("abc") <= Str("abcd") + + # Greater than and greater than or equal to + assert Str("b") > "aa" + assert Str("ab") >= "ab" + assert Str("b") > Str("a") + assert Str("abcd") >= Str("abc") + + # Slicing and comparisons s2 = Str("abb") assert s2[1:] == "bb" assert s2[:-1] == "ab" assert s2[-1:] == "b" - + assert s2[1:] != "abb" + assert s2[:-2] == "a" + assert s2[-2:] == "bb" + + +def test_unit_strs_rich_comparisons(): + arr: Strs = Str("a b c d e f g h").split() + + # Test against another Strs object + identical_arr: Strs = Str("a b c d e f g h").split() + different_arr: Strs = Str("a b c d e f g i").split() + shorter_arr: Strs = Str("a b c d e").split() + longer_arr: Strs = Str("a b c d e f g h i j").split() + + assert arr == identical_arr + assert arr != different_arr + assert arr != shorter_arr + assert arr != longer_arr + assert shorter_arr < arr + assert longer_arr > arr + + # Test against a Python list and a tuple + list_equal = ["a", "b", "c", "d", "e", "f", "g", "h"] + list_different = ["a", "b", "c", "d", "x", "f", "g", "h"] + tuple_equal = ("a", "b", "c", "d", "e", "f", "g", "h") + 
tuple_different = ("a", "b", "c", "d", "e", "f", "g", "i") + + assert arr == list_equal + assert arr != list_different + assert arr == tuple_equal + assert arr != tuple_different + + # Test against a generator of unknown length + generator_equal = (x for x in "a b c d e f g h".split()) + generator_different = (x for x in "a b c d e f g i".split()) + generator_shorter = (x for x in "a b c d e".split()) + generator_longer = (x for x in "a b c d e f g h i j".split()) + + assert arr == generator_equal + assert arr != generator_different + assert arr != generator_shorter + assert arr != generator_longer def test_unit_buffer_protocol(): # Try importing NumPy to compute the Levenshtein distances more efficiently From fd48df95f94fc0e9578167b7ed4115938f796993 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 29 Feb 2024 00:59:06 +0000 Subject: [PATCH 04/13] Add: Python slices with steps for `Strs` --- python/lib.c | 282 +++++++++++++++++++++++++++++++++--------------- scripts/test.py | 100 ++++++++++++++--- 2 files changed, 277 insertions(+), 105 deletions(-) diff --git a/python/lib.c b/python/lib.c index 743aaef6..685df462 100644 --- a/python/lib.c +++ b/python/lib.c @@ -106,11 +106,14 @@ typedef struct { * offsets. The starting offset of the first element is zero bytes after the `start`. * Every chunk will include a separator of length `separator_length` at the end, except for the * last one. + * + * The layout isn't exactly identical to Arrow, as we have an optional separator and we have one less offset. + * https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout */ struct consecutive_slices_32bit_t { size_t count; size_t separator_length; - PyObject *parent; + PyObject *parent_string; char const *start; uint32_t *end_offsets; } consecutive_32bit; @@ -123,11 +126,14 @@ typedef struct { * offsets. The starting offset of the first element is zero bytes after the `start`. 
* Every chunk will include a separator of length `separator_length` at the end, except for the * last one. + * + * The layout isn't exactly identical to Arrow, as we have an optional separator and we have one less offset. + * https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout */ struct consecutive_slices_64bit_t { size_t count; size_t separator_length; - PyObject *parent; + PyObject *parent_string; char const *start; uint64_t *end_offsets; } consecutive_64bit; @@ -138,7 +144,7 @@ typedef struct { */ struct reordered_slices_t { size_t count; - PyObject *parent; + PyObject *parent_string; sz_string_view_t *parts; } reordered; @@ -243,29 +249,29 @@ sz_bool_t export_string_like(PyObject *object, sz_cptr_t **start, sz_size_t *len typedef void (*get_string_at_offset_t)(Strs *, Py_ssize_t, Py_ssize_t, PyObject **, char const **, size_t *); -void str_at_offset_consecutive_32bit(Strs *strs, Py_ssize_t i, Py_ssize_t count, PyObject **parent, char const **start, - size_t *length) { +void str_at_offset_consecutive_32bit(Strs *strs, Py_ssize_t i, Py_ssize_t count, // + PyObject **parent_string, char const **start, size_t *length) { uint32_t start_offset = (i == 0) ? 0 : strs->data.consecutive_32bit.end_offsets[i - 1]; uint32_t end_offset = strs->data.consecutive_32bit.end_offsets[i]; *start = strs->data.consecutive_32bit.start + start_offset; *length = end_offset - start_offset - strs->data.consecutive_32bit.separator_length * (i + 1 != count); - *parent = strs->data.consecutive_32bit.parent; + *parent_string = strs->data.consecutive_32bit.parent_string; } -void str_at_offset_consecutive_64bit(Strs *strs, Py_ssize_t i, Py_ssize_t count, PyObject **parent, char const **start, - size_t *length) { +void str_at_offset_consecutive_64bit(Strs *strs, Py_ssize_t i, Py_ssize_t count, // + PyObject **parent_string, char const **start, size_t *length) { uint64_t start_offset = (i == 0) ? 
0 : strs->data.consecutive_64bit.end_offsets[i - 1]; uint64_t end_offset = strs->data.consecutive_64bit.end_offsets[i]; *start = strs->data.consecutive_64bit.start + start_offset; *length = end_offset - start_offset - strs->data.consecutive_64bit.separator_length * (i + 1 != count); - *parent = strs->data.consecutive_64bit.parent; + *parent_string = strs->data.consecutive_64bit.parent_string; } -void str_at_offset_reordered(Strs *strs, Py_ssize_t i, Py_ssize_t count, PyObject **parent, char const **start, - size_t *length) { +void str_at_offset_reordered(Strs *strs, Py_ssize_t i, Py_ssize_t count, // + PyObject **parent_string, char const **start, size_t *length) { *start = strs->data.reordered.parts[i].start; *length = strs->data.reordered.parts[i].length; - *parent = strs->data.reordered.parent; + *parent_string = strs->data.reordered.parent_string; } get_string_at_offset_t str_at_offset_getter(Strs *strs) { @@ -286,18 +292,18 @@ sz_bool_t prepare_strings_for_reordering(Strs *strs) { size_t count = 0; void *old_buffer = NULL; get_string_at_offset_t getter = NULL; - PyObject *parent = NULL; + PyObject *parent_string = NULL; switch (strs->type) { case STRS_CONSECUTIVE_32: count = strs->data.consecutive_32bit.count; old_buffer = strs->data.consecutive_32bit.end_offsets; - parent = strs->data.consecutive_32bit.parent; + parent_string = strs->data.consecutive_32bit.parent_string; getter = str_at_offset_consecutive_32bit; break; case STRS_CONSECUTIVE_64: count = strs->data.consecutive_64bit.count; old_buffer = strs->data.consecutive_64bit.end_offsets; - parent = strs->data.consecutive_64bit.parent; + parent_string = strs->data.consecutive_64bit.parent_string; getter = str_at_offset_consecutive_64bit; break; // Already in reordered form @@ -317,10 +323,10 @@ sz_bool_t prepare_strings_for_reordering(Strs *strs) { // Populate the new reordered array using get_string_at_offset for (size_t i = 0; i < count; ++i) { - PyObject *parent; + PyObject *parent_string; char const 
*start; size_t length; - getter(strs, (Py_ssize_t)i, count, &parent, &start, &length); + getter(strs, (Py_ssize_t)i, count, &parent_string, &start, &length); new_parts[i].start = start; new_parts[i].length = length; } @@ -332,7 +338,7 @@ sz_bool_t prepare_strings_for_reordering(Strs *strs) { strs->type = STRS_REORDERED; strs->data.reordered.count = count; strs->data.reordered.parts = new_parts; - strs->data.reordered.parent = parent; + strs->data.reordered.parent_string = parent_string; return 1; } @@ -593,6 +599,9 @@ static PyObject *Str_like_hash(PyObject *self, PyObject *args, PyObject *kwargs) return PyLong_FromSize_t((size_t)result); } +static PyObject *Str_get_address(Str *self, void *closure) { return PyLong_FromSize_t((sz_size_t)self->start); } +static PyObject *Str_get_nbytes(Str *self, void *closure) { return PyLong_FromSize_t(self->length); } + static Py_ssize_t Str_len(Str *self) { return self->length; } static PyObject *Str_getitem(Str *self, Py_ssize_t i) { @@ -680,6 +689,13 @@ static int Str_in(Str *self, PyObject *arg) { return sz_find(self->start, self->length, needle_struct.start, needle_struct.length) != NULL; } +static PyObject *Strs_get_tape(Str *self, void *closure) { return NULL; } +static PyObject *Strs_get_offsets_are_large(Str *self, void *closure) { return NULL; } +static PyObject *Strs_get_tape_address(Str *self, void *closure) { return NULL; } +static PyObject *Strs_get_offsets_address(Str *self, void *closure) { return NULL; } +static PyObject *Strs_get_tape_nbytes(Str *self, void *closure) { return NULL; } +static PyObject *Strs_get_offsets_nbytes(Str *self, void *closure) { return NULL; } + static Py_ssize_t Strs_len(Strs *self) { switch (self->type) { case STRS_CONSECUTIVE_32: return self->data.consecutive_32bit.count; @@ -721,83 +737,152 @@ static PyObject *Strs_getitem(Strs *self, Py_ssize_t i) { } static PyObject *Strs_subscript(Strs *self, PyObject *key) { - if (PySlice_Check(key)) { - // Sanity checks - Py_ssize_t count = 
Strs_len(self); - Py_ssize_t start, stop, step; - if (PySlice_Unpack(key, &start, &stop, &step) < 0) return NULL; - if (PySlice_AdjustIndices(count, &start, &stop, step) < 0) return NULL; - if (step != 1) { - PyErr_SetString(PyExc_IndexError, "Efficient step is not supported"); - return NULL; - } - // Create a new `Strs` object - Strs *self_slice = (Strs *)StrsType.tp_alloc(&StrsType, 0); - if (self_slice == NULL && PyErr_NoMemory()) return NULL; + if (PyLong_Check(key)) { return Strs_getitem(self, PyLong_AsSsize_t(key)); } - // Depending on the layout, the procedure will be different. - self_slice->type = self->type; - switch (self->type) { - -/* Usable as consecutive_logic(64bit), e.g. */ -#define consecutive_logic(type) \ - typedef index_##type##_t index_t; \ - typedef struct consecutive_slices_##type##_t slice_t; \ - slice_t *from = &self->data.consecutive_##type; \ - slice_t *to = &self_slice->data.consecutive_##type; \ - to->count = stop - start; \ - to->separator_length = from->separator_length; \ - to->parent = from->parent; \ - size_t first_length; \ - str_at_offset_consecutive_##type(self, start, count, &to->parent, &to->start, &first_length); \ - index_t first_offset = to->start - from->start; \ - to->end_offsets = malloc(sizeof(index_t) * to->count); \ - if (to->end_offsets == NULL && PyErr_NoMemory()) { \ - Py_XDECREF(self_slice); \ - return NULL; \ - } \ - for (size_t i = 0; i != to->count; ++i) to->end_offsets[i] = from->end_offsets[i + start] - first_offset; \ - Py_INCREF(to->parent); - case STRS_CONSECUTIVE_32: { - typedef uint32_t index_32bit_t; - consecutive_logic(32bit); - break; - } - case STRS_CONSECUTIVE_64: { - typedef uint64_t index_64bit_t; - consecutive_logic(64bit); - break; + if (!PySlice_Check(key)) { + PyErr_SetString(PyExc_TypeError, "Strs indices must be integers or slices"); + return NULL; + } + + // Sanity checks + Py_ssize_t count = Strs_len(self); + Py_ssize_t start, stop, step; + if (PySlice_Unpack(key, &start, &stop, &step) < 
0) return NULL; + Py_ssize_t result_count = PySlice_AdjustIndices(count, &start, &stop, step); + if (result_count < 0) return NULL; + + // Create a new `Strs` object + Strs *result = (Strs *)StrsType.tp_alloc(&StrsType, 0); + if (result == NULL && PyErr_NoMemory()) return NULL; + if (result_count == 0) { + result->type = STRS_REORDERED; + result->data.reordered.count = 0; + result->data.reordered.parts = NULL; + result->data.reordered.parent_string = NULL; + return (PyObject *)result; + } + + // If a step is requested, we have to create a new `REORDERED` Strs object, + // even if the original one was `CONSECUTIVE`. + if (step != 1) { + sz_string_view_t *new_parts = (sz_string_view_t *)malloc(result_count * sizeof(sz_string_view_t)); + if (new_parts == NULL) { + Py_XDECREF(result); + PyErr_SetString(PyExc_MemoryError, "Unable to allocate memory for reordered slices"); + return 0; } -#undef consecutive_logic - case STRS_REORDERED: { - struct reordered_slices_t *from = &self->data.reordered; - struct reordered_slices_t *to = &self_slice->data.reordered; - to->count = stop - start; - to->parent = from->parent; - - to->parts = malloc(sizeof(sz_string_view_t) * to->count); - if (to->parts == NULL && PyErr_NoMemory()) { - Py_XDECREF(self_slice); - return NULL; + + get_string_at_offset_t getter = str_at_offset_getter(self); + result->type = STRS_REORDERED; + result->data.reordered.count = result_count; + result->data.reordered.parts = new_parts; + result->data.reordered.parent_string = NULL; + + // Populate the new reordered array using get_string_at_offset + size_t j = 0; + if (step > 0) + for (Py_ssize_t i = start; i < stop; i += step, ++j) { + getter(self, i, count, &result->data.reordered.parent_string, &new_parts[j].start, + &new_parts[j].length); } - memcpy(to->parts, from->parts + start, sizeof(sz_string_view_t) * to->count); - Py_INCREF(to->parent); - break; + else + for (Py_ssize_t i = start; i > stop; i += step, ++j) { + getter(self, i, count, 
&result->data.reordered.parent_string, &new_parts[j].start, + &new_parts[j].length); + } + + return (PyObject *)result; + } + + // Depending on the layout, the procedure will be different, but by now we know that: + // - `start` and `stop` are valid indices + // - `step` is 1 + // - `result_count` is positive + // - the resulting object will have the same type as the original one + result->type = self->type; + switch (self->type) { + + case STRS_CONSECUTIVE_32: { + typedef struct consecutive_slices_32bit_t consecutive_slices_t; + consecutive_slices_t *from = &self->data.consecutive_32bit; + consecutive_slices_t *to = &result->data.consecutive_32bit; + to->count = result_count; + + // Allocate memory for the end offsets + to->separator_length = from->separator_length; + to->end_offsets = malloc(sizeof(index_32bit_t) * result_count); + if (to->end_offsets == NULL && PyErr_NoMemory()) { + Py_XDECREF(result); + return NULL; + } + + // Now populate the offsets + size_t element_length; + str_at_offset_consecutive_32bit(self, start, count, &to->parent_string, &to->start, &element_length); + to->end_offsets[0] = element_length; + for (size_t i = 1; i < result_count; ++i) { + to->end_offsets[i - 1] += from->separator_length; + PyObject *element_parent = NULL; + char const *element_start = NULL; + str_at_offset_consecutive_32bit(self, start, count, &element_parent, &element_start, &element_length); + to->end_offsets[i] = element_length + to->end_offsets[i - 1]; } - default: - // Unsupported type - PyErr_SetString(PyExc_TypeError, "Unsupported type for conversion"); + Py_INCREF(to->parent_string); + break; + } + + case STRS_CONSECUTIVE_64: { + typedef struct consecutive_slices_64bit_t consecutive_slices_t; + consecutive_slices_t *from = &self->data.consecutive_64bit; + consecutive_slices_t *to = &result->data.consecutive_64bit; + to->count = result_count; + + // Allocate memory for the end offsets + to->separator_length = from->separator_length; + to->end_offsets = 
malloc(sizeof(index_64bit_t) * result_count); + if (to->end_offsets == NULL && PyErr_NoMemory()) { + Py_XDECREF(result); return NULL; } - return (PyObject *)self_slice; + // Now populate the offsets + size_t element_length; + str_at_offset_consecutive_64bit(self, start, count, &to->parent_string, &to->start, &element_length); + to->end_offsets[0] = element_length; + for (size_t i = 1; i < result_count; ++i) { + to->end_offsets[i - 1] += from->separator_length; + PyObject *element_parent = NULL; + char const *element_start = NULL; + str_at_offset_consecutive_64bit(self, start, count, &element_parent, &element_start, &element_length); + to->end_offsets[i] = element_length + to->end_offsets[i - 1]; + } + Py_INCREF(to->parent_string); + break; } - else if (PyLong_Check(key)) { return Strs_getitem(self, PyLong_AsSsize_t(key)); } - else { - PyErr_SetString(PyExc_TypeError, "Strs indices must be integers or slices"); + + case STRS_REORDERED: { + struct reordered_slices_t *from = &self->data.reordered; + struct reordered_slices_t *to = &result->data.reordered; + to->count = result_count; + to->parent_string = from->parent_string; + + to->parts = malloc(sizeof(sz_string_view_t) * to->count); + if (to->parts == NULL && PyErr_NoMemory()) { + Py_XDECREF(result); + return NULL; + } + memcpy(to->parts, from->parts + start, sizeof(sz_string_view_t) * to->count); + Py_INCREF(to->parent_string); + break; + } + default: + // Unsupported type + PyErr_SetString(PyExc_TypeError, "Unsupported type for conversion"); return NULL; } + + return (PyObject *)result; } // Will be called by the `PySequence_Contains` @@ -1712,7 +1797,7 @@ static PyObject *Str_find_last_not_of(PyObject *self, PyObject *args, PyObject * return PyLong_FromSsize_t(signed_offset); } -static Strs *Str_split_(PyObject *parent, sz_string_view_t text, sz_string_view_t separator, int keepseparator, +static Strs *Str_split_(PyObject *parent_string, sz_string_view_t text, sz_string_view_t separator, int keepseparator, 
Py_ssize_t maxsplit) { // Create Strs object Strs *result = (Strs *)PyObject_New(Strs, &StrsType); @@ -1727,14 +1812,14 @@ static Strs *Str_split_(PyObject *parent, sz_string_view_t text, sz_string_view_ bytes_per_offset = 8; result->type = STRS_CONSECUTIVE_64; result->data.consecutive_64bit.start = text.start; - result->data.consecutive_64bit.parent = parent; + result->data.consecutive_64bit.parent_string = parent_string; result->data.consecutive_64bit.separator_length = !keepseparator * separator.length; } else { bytes_per_offset = 4; result->type = STRS_CONSECUTIVE_32; result->data.consecutive_32bit.start = text.start; - result->data.consecutive_32bit.parent = parent; + result->data.consecutive_32bit.parent_string = parent_string; result->data.consecutive_32bit.separator_length = !keepseparator * separator.length; } @@ -1781,7 +1866,7 @@ static Strs *Str_split_(PyObject *parent, sz_string_view_t text, sz_string_view_ result->data.consecutive_32bit.count = offsets_count; } - Py_INCREF(parent); + Py_INCREF(parent_string); return result; } @@ -1976,6 +2061,13 @@ static PyNumberMethods Str_as_number = { .nb_add = Str_concat, }; +static PyGetSetDef Str_getsetters[] = { + // Compatibility with PyArrow + {"address", (getter)Str_get_address, NULL, "Get the memory address of the first byte of the string", NULL}, + {"nbytes", (getter)Str_get_nbytes, NULL, "Get the length of the string in bytes", NULL}, + {NULL} // Sentinel +}; + #define SZ_METHOD_FLAGS METH_VARARGS | METH_KEYWORDS static PyMethodDef Str_methods[] = { @@ -2040,6 +2132,7 @@ static PyTypeObject StrType = { .tp_as_mapping = &Str_as_mapping, .tp_as_buffer = &Str_as_buffer, .tp_as_number = &Str_as_number, + .tp_getset = Str_getsetters, }; #pragma endregion @@ -2296,6 +2389,18 @@ static PyMappingMethods Strs_as_mapping = { .mp_subscript = Strs_subscript, // Is used to implement slices in Python }; +static PyGetSetDef Strs_getsetters[] = { + // Compatibility with PyArrow + {"tape", (getter)Strs_get_tape, NULL, 
"In-place transforms the string representation to match Apache Arrow", NULL}, + {"tape_address", (getter)Strs_get_tape_address, NULL, "Address of the first byte of the first string", NULL}, + {"tape_nbytes", (getter)Strs_get_tape_nbytes, NULL, "Length of the entire tape of strings in bytes", NULL}, + {"offsets_address", (getter)Strs_get_offsets_address, NULL, "Address of the first byte of offsets array", NULL}, + {"offsets_nbytes", (getter)Strs_get_offsets_nbytes, NULL, "Get teh length of offsets array in bytes", NULL}, + {"offsets_are_large", (getter)Strs_get_offsets_are_large, NULL, + "Checks if 64-bit addressing should be used to convert to Arrow", NULL}, + {NULL} // Sentinel +}; + static PyMethodDef Strs_methods[] = { {"shuffle", Strs_shuffle, SZ_METHOD_FLAGS, "Shuffle the elements of the Strs object."}, // {"sort", Strs_sort, SZ_METHOD_FLAGS, "Sort the elements of the Strs object."}, // @@ -2312,6 +2417,7 @@ static PyTypeObject StrsType = { .tp_methods = Strs_methods, .tp_as_sequence = &Strs_as_sequence, .tp_as_mapping = &Strs_as_mapping, + .tp_getset = Strs_getsetters, .tp_richcompare = Strs_richcompare, }; diff --git a/scripts/test.py b/scripts/test.py index d9775fea..8df068af 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -7,6 +7,31 @@ import stringzilla as sz from stringzilla import Str, Strs +# NumPy is available on most platforms and is required for most tests. +# When using PyPy on some platforms NumPy has internal issues, that will +# raise a weird error, not an `ImportError`. That's why we intentionally +# use a naked `except:`. Necessary evil! +try: + import numpy as np + + numpy_available = True +except: + # NumPy is not installed, most tests will be skipped + numpy_available = False + + +# PyArrow is not available on most platforms. +# When using PyPy on some platforms PyArrow has internal issues, that will +# raise a weird error, not an `ImportError`. That's why we intentionally +# use a naked `except:`. Necessary evil! 
+try: + import pyarrow as pa + + pyarrow_available = True +except: + # NumPy is not installed, most tests will be skipped + pyarrow_available = False + def test_library_properties(): assert len(sz.__version__.split(".")) == 3, "Semantic versioning must be preserved" @@ -108,13 +133,9 @@ def test_unit_strs_rich_comparisons(): assert arr != generator_shorter assert arr != generator_longer -def test_unit_buffer_protocol(): - # Try importing NumPy to compute the Levenshtein distances more efficiently - try: - import numpy as np - except ImportError: - pytest.skip("NumPy is not installed") +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") +def test_unit_buffer_protocol(): my_str = Str("hello") arr = np.array(my_str) assert arr.dtype == np.dtype("c") @@ -140,7 +161,7 @@ def test_unit_split(): assert str(parts[2]) == "token3" -def test_unit_sequence(): +def test_unit_strs_sequence(): native = "p3\np2\np1" big = Str(native) @@ -159,6 +180,48 @@ def test_unit_sequence(): assert ["p3", "p2", "p1"] == list(lines) +def test_unit_slicing(): + native = "abcdef" + big = Str(native) + assert big[1:3] == "bc" + assert big[1:] == "bcdef" + assert big[:3] == "abc" + assert big[-1:] == "f" + assert big[:-1] == "abcde" + assert big[-3:] == "def" + assert big[:-3] == "abc" + + +def test_unit_strs_sequence_slicing(): + native = "1, 2, 3, 4, 5, 6" + big = Str(native) + big_sequence = big.split(", ") + + def to_str(seq): + return "".join([str(x) for x in seq]) + + assert big_sequence[1:3] == ["2", "3"], to_str(big_sequence[1:3]) + assert big_sequence[1:] == ["2", "3", "4", "5", "6"], to_str(big_sequence[1:]) + assert big_sequence[:3] == ["1", "2", "3"], to_str(big_sequence[:3]) + + # Use negative indices to slice from the end + assert big_sequence[-1:] == ["6"], to_str(big_sequence[-1:]) + assert big_sequence[:-1] == ["1", "2", "3", "4", "5"], to_str(big_sequence[:-1]) + assert big_sequence[-3:] == ["4", "5", "6"], to_str(big_sequence[-3:]) + assert 
big_sequence[:-3] == ["1", "2", "3"], to_str(big_sequence[:-3]) + + # Introduce a step to skip some values + assert big_sequence[::2] == ["1", "3", "5"], to_str(big_sequence[::2]) + assert big_sequence[::-1] == ["6", "5", "4", "3", "2", "1"], to_str( + big_sequence[::-1] + ) + + # Let's go harder with nested slicing + assert big_sequence[1:][::-1] == ["6", "5", "4", "3", "2"] + assert big_sequence[1:][::-2] == ["6", "4", "2"] + assert big_sequence[1:][::-3] == ["6", "3"] + + def test_unit_globals(): """Validates that the previously unit-tested member methods are also visible as global functions.""" @@ -222,15 +285,11 @@ def is_equal_strings(native_strings, big_strings): ), f"Mismatch between `{native_slice}` and `{str(big_slice)}`" +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") def baseline_edit_distance(s1, s2) -> int: """ Compute the Levenshtein distance between two strings. """ - # Try importing NumPy to compute the Levenshtein distances more efficiently - try: - import numpy as np - except ImportError: - pytest.skip("NumPy is not installed") # Create a matrix of size (len(s1)+1) x (len(s2)+1) matrix = np.zeros((len(s1) + 1, len(s2) + 1), dtype=int) @@ -405,12 +464,8 @@ def test_edit_distance_random(first_length: int, second_length: int): @pytest.mark.repeat(30) @pytest.mark.parametrize("first_length", [20, 100]) @pytest.mark.parametrize("second_length", [20, 100]) +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") def test_alignment_score_random(first_length: int, second_length: int): - # Try importing NumPy to compute the Levenshtein distances more efficiently - try: - import numpy as np - except ImportError: - pytest.skip("NumPy is not installed") a = get_random_string(length=first_length) b = get_random_string(length=second_length) @@ -480,3 +535,14 @@ def test_fuzzy_sorting(list_length: int, part_length: int, variability: int): assert len(native_list) == len(big_list) for native_str, big_str in 
zip(native_list, big_list): assert native_str == str(big_str), "Order is wrong" + + +@pytest.mark.skipif(not pyarrow_available, reason="PyArrow is not installed") +def test_pyarrow_str_conversion(): + native = "hello" + big = Str(native) + assert isinstance(big.address, int) and big.address != 0 + assert isinstance(big.nbytes, int) and big.nbytes == len(native) + + arrow_buffer = pa.foreign_buffer(big.address, big.nbytes, big) + assert arrow_buffer.to_pybytes() == native.encode("utf-8") From 0e5c2f56ff9ebd1b67f257bcbf7051b96c54d9b7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 29 Feb 2024 04:17:47 +0000 Subject: [PATCH 05/13] Add: `Strs.sample()` functionality Random-sampling a large file is now 10x faster. 2 secs to sample 1000 rows from 1 GB wiki with Python. 200 ms to do the sample with `sz.splitlines().sample()`. --- README.md | 34 ++++++++++++++++- python/lib.c | 98 ++++++++++++++++++++++++++++++++++++++++++++++--- scripts/test.py | 5 +++ 3 files changed, 130 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index acd7bcff..4e5c00e9 100644 --- a/README.md +++ b/README.md @@ -354,14 +354,29 @@ Assuming superior search speed splitting should also work 3x faster than with na Need copies? ```python -sorted_copy: Strs = lines.sorted() -shuffled_copy: Strs = lines.shuffled(seed=42) +words = Str('b a c').split() +sorted_copy: Strs = words.sorted() +assert words == ['a', 'b', 'c'] # Supports all comparison operators +shuffled_copy: Strs = words.shuffled(seed=42) ``` Those collections of `Strs` are designed to keep the memory consumption low. If all the chunks are located in consecutive memory regions, the memory overhead can be as low as 4 bytes per chunk. That's designed to handle very large datasets, like [RedPajama][redpajama]. To address all 20 Billion annotated english documents in it, one will need only 160 GB of RAM instead of Terabytes. 
+Once memory mapped, you can random-sample the data without loading it into RAM: + +```python +lines.sample(seed=42) # 10x faster than `random.choices` +``` + +Or you can batch and shard the data using slicing in any order with any steps: + +```python +lines[::3] # every third line +lines[1::1] # every odd line +lines[:-100:-1] # last 100 lines in reverse order +``` [redpajama]: https://github.com/togethercomputer/RedPajama-Data @@ -454,6 +469,21 @@ assert sz.alignment_score( +### Serialization + +#### Filesystem + +Similar to how `File` can be used to read a large file, other interfaces can be used to dump strings to disk faster. +The `Str` class has `write_to` to write the string to a file, and `offset_within` to obtain integer offsets of substring view in larger string for navigation. + +```py +web_archieve = Str("......") +_, end_tag, next_doc = web_archieve.partition("") # or use `find` +next_doc_offset = next_doc.offset_within(web_archieve) +web_archieve.write_to("next_doc.html") +``` + + ## Quick Start: C/C++ 🛠️ The C library is header-only, so you can just copy the `stringzilla.h` header into your project. 
diff --git a/python/lib.c b/python/lib.c index 685df462..6b556b60 100644 --- a/python/lib.c +++ b/python/lib.c @@ -36,7 +36,9 @@ typedef SSIZE_T ssize_t; #include // Core CPython interfaces #include // `fopen` +#include // `rand`, `srand` #include // `memset`, `memcpy` +#include // `time` #include @@ -811,7 +813,7 @@ static PyObject *Strs_subscript(Strs *self, PyObject *key) { // Allocate memory for the end offsets to->separator_length = from->separator_length; - to->end_offsets = malloc(sizeof(index_32bit_t) * result_count); + to->end_offsets = malloc(sizeof(uint32_t) * result_count); if (to->end_offsets == NULL && PyErr_NoMemory()) { Py_XDECREF(result); return NULL; @@ -840,7 +842,7 @@ static PyObject *Strs_subscript(Strs *self, PyObject *key) { // Allocate memory for the end offsets to->separator_length = from->separator_length; - to->end_offsets = malloc(sizeof(index_64bit_t) * result_count); + to->end_offsets = malloc(sizeof(uint64_t) * result_count); if (to->end_offsets == NULL && PyErr_NoMemory()) { Py_XDECREF(result); return NULL; @@ -2378,6 +2380,90 @@ static PyObject *Strs_order(Strs *self, PyObject *args, PyObject *kwargs) { return tuple; } +static PyObject *Strs_sample(Strs *self, PyObject *args, PyObject *kwargs) { + PyObject *seed_obj = NULL; + PyObject *sample_size_obj = NULL; + + // Check for positional arguments + Py_ssize_t nargs = PyTuple_Size(args); + if (nargs > 1) { + PyErr_SetString(PyExc_TypeError, "sample() takes 1 positional argument and 1 keyword argument"); + return NULL; + } + else if (nargs == 1) { sample_size_obj = PyTuple_GET_ITEM(args, 0); } + + // Parse keyword arguments + if (kwargs) { + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(kwargs, &pos, &key, &value)) { + if (PyUnicode_CompareWithASCIIString(key, "seed") == 0) { seed_obj = value; } + else { + PyErr_Format(PyExc_TypeError, "Got an unexpected keyword argument '%U'", key); + return 0; + } + } + } + + // Translate the seed and the sample size to C types + 
size_t sample_size = 0; + if (sample_size_obj) { + if (!PyLong_Check(sample_size_obj)) { + PyErr_SetString(PyExc_TypeError, "The sample size must be an integer"); + return NULL; + } + sample_size = PyLong_AsSize_t(sample_size_obj); + } + unsigned int seed = time(NULL); // Default seed + if (seed_obj) { + if (!PyLong_Check(seed_obj)) { + PyErr_SetString(PyExc_TypeError, "The seed must be an integer"); + return NULL; + } + seed = PyLong_AsUnsignedLong(seed_obj); + } + + // Create a new `Strs` object + Strs *result = (Strs *)StrsType.tp_alloc(&StrsType, 0); + if (result == NULL && PyErr_NoMemory()) return NULL; + + result->type = STRS_REORDERED; + result->data.reordered.count = 0; + result->data.reordered.parts = NULL; + result->data.reordered.parent_string = NULL; + if (sample_size == 0) { return (PyObject *)result; } + + // Now create a new Strs object with the sampled strings + sz_string_view_t *result_parts = malloc(sample_size * sizeof(sz_string_view_t)); + if (!result_parts) { + PyErr_SetString(PyExc_MemoryError, "Failed to allocate memory for the sample"); + return NULL; + } + + // Introspect the Strs object to know the from which will be sampling + Py_ssize_t count = Strs_len(self); + get_string_at_offset_t getter = str_at_offset_getter(self); + if (!getter) { + PyErr_SetString(PyExc_TypeError, "Unknown Strs kind"); + return NULL; + } + + // Randomly sample the strings + srand(seed); + PyObject *parent_string; + for (Py_ssize_t i = 0; i < sample_size; i++) { + size_t index = rand() % count; + getter(self, index, count, &parent_string, &result_parts[i].start, &result_parts[i].length); + } + + // Update the Strs object + result->type = STRS_REORDERED; + result->data.reordered.count = sample_size; + result->data.reordered.parts = result_parts; + result->data.reordered.parent_string = parent_string; + return result; +} + static PySequenceMethods Strs_as_sequence = { .sq_length = Strs_len, // .sq_item = Strs_getitem, // @@ -2402,9 +2488,11 @@ static PyGetSetDef 
Strs_getsetters[] = { }; static PyMethodDef Strs_methods[] = { - {"shuffle", Strs_shuffle, SZ_METHOD_FLAGS, "Shuffle the elements of the Strs object."}, // - {"sort", Strs_sort, SZ_METHOD_FLAGS, "Sort the elements of the Strs object."}, // - {"order", Strs_order, SZ_METHOD_FLAGS, "Provides the indexes to achieve sorted order."}, // + {"shuffle", Strs_shuffle, SZ_METHOD_FLAGS, "Shuffle (in-place) the elements of the Strs object."}, // + {"sort", Strs_sort, SZ_METHOD_FLAGS, "Sort (in-place) the elements of the Strs object."}, // + {"order", Strs_order, SZ_METHOD_FLAGS, "Provides the indexes to achieve sorted order."}, // + {"sample", Strs_sample, SZ_METHOD_FLAGS, "Provides a random sample of a given size."}, // + // {"to_pylist", Strs_to_pylist, SZ_METHOD_FLAGS, "Exports string-views to a native list of native strings."}, // {NULL, NULL, 0, NULL}}; static PyTypeObject StrsType = { diff --git a/scripts/test.py b/scripts/test.py index 8df068af..01ed4e2b 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -179,6 +179,11 @@ def test_unit_strs_sequence(): lines.sort(reverse=True) assert ["p3", "p2", "p1"] == list(lines) + # Sampling an array + sampled = lines.sample(100, seed=42) + assert "p3" in sampled + assert "p4" not in sampled + def test_unit_slicing(): native = "abcdef" From 3b6cdddbcba50b326eb44fa799381b978d99bdc5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 1 Mar 2024 22:41:40 +0000 Subject: [PATCH 06/13] Add: Lazy iterators for Python --- CONTRIBUTING.md | 6 + README.md | 76 ++++-- python/lib.c | 557 +++++++++++++++++++++++++++++++++++++------- scripts/bench.ipynb | 320 +++++++++++-------------- scripts/test.py | 225 ++++++++++++++---- 5 files changed, 853 insertions(+), 331 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a844cbdc..75bc203e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -136,6 +136,12 @@ cppcheck --project=build_artifacts/compile_commands.json --enable=all 
clang-tidy-11 -p build_artifacts ``` +I'd recommend putting the following breakpoints: + +- `__asan::ReportGenericError` - to detect illegal memory accesses. +- `__GI_exit` - to stop at exit points - the end of running any executable. +- `__builtin_unreachable` - to catch all the places where the code is expected to be unreachable. + ### Benchmarking For benchmarks, you can use the following commands: diff --git a/README.md b/README.md index 4e5c00e9..b4905bdd 100644 --- a/README.md +++ b/README.md @@ -337,40 +337,48 @@ A standard dataset pre-processing use case would be to map a sizeable textual da - `text.contains('substring', start=0, end=9223372036854775807) -> bool` - `text.find('substring', start=0, end=9223372036854775807) -> int` - `text.count('substring', start=0, end=9223372036854775807, allowoverlap=False) -> int` -- `text.splitlines(keeplinebreaks=False, separator='\n') -> Strs` - `text.split(separator=' ', maxsplit=9223372036854775807, keepseparator=False) -> Strs` +- `text.rsplit(separator=' ', maxsplit=9223372036854775807, keepseparator=False) -> Strs` +- `text.splitlines(keeplinebreaks=False, maxsplit=9223372036854775807) -> Strs` -### Collection-Level Operations +It's important to note, that the last function behavior is slightly different from Python's `str.splitlines`. +The [native version][faq-splitlines] matches `\n`, `\r`, `\v` or `\x0b`, `\f` or `\x0c`, `\x1c`, `\x1d`, `\x1e`, `\x85`, `\r\n`, `\u2028`, `\u2029`, including 3x two-bytes-long runes. +The StringZilla version matches only `\n` and is practically a shortcut for `text.split('\n')`. -Once split into a `Strs` object, you can sort, shuffle, and reorganize the slices. 
+[faq-splitlines]: https://docs.python.org/3/library/stdtypes.html#str.splitlines -```python -lines: Strs = text.split(separator='\n') # 4 bytes per line overhead for under 4 GB of text -lines.sort() # explodes to 16 bytes per line overhead for any length text -lines.shuffle(seed=42) # reproducing dataset shuffling with a seed -``` +### Character Set Operations -Assuming superior search speed splitting should also work 3x faster than with native Python strings. -Need copies? +Python strings don't natively support character set operations. +This forces people to use regular expressions, which are slow and hard to read. +To avoid the need for `re.finditer`, StringZilla provides the following interfaces: -```python -words = Str('b a c').split() -sorted_copy: Strs = words.sorted() -assert words == ['a', 'b', 'c'] # Supports all comparison operators -shuffled_copy: Strs = words.shuffled(seed=42) -``` +- `text.find_first_of('chars', start=0, end=9223372036854775807) -> int` +- `text.find_last_of('chars', start=0, end=9223372036854775807) -> int` +- `text.find_first_not_of('chars', start=0, end=9223372036854775807) -> int` +- `text.find_last_not_of('chars', start=0, end=9223372036854775807) -> int` + +Similarly, for splitting operations: + +- `text.split_charset(separator='chars', maxsplit=9223372036854775807, keepseparator=False) -> Strs` +- `text.rsplit_charset(separator='chars', maxsplit=9223372036854775807, keepseparator=False) -> Strs` + +### Collection-Level Operations -Those collections of `Strs` are designed to keep the memory consumption low. +Once split into a `Strs` object, you can sort, shuffle, and reorganize the slices, with minimum memory footprint. If all the chunks are located in consecutive memory regions, the memory overhead can be as low as 4 bytes per chunk. -That's designed to handle very large datasets, like [RedPajama][redpajama]. -To address all 20 Billion annotated english documents in it, one will need only 160 GB of RAM instead of Terabytes. 
-Once memory mapped, you can random-sample the data without loading it into RAM: ```python -lines.sample(seed=42) # 10x faster than `random.choices` +lines: Strs = text.split(separator='\n') # 4 bytes per line overhead for under 4 GB of text +batch: Strs = lines.sample(seed=42) # 10x faster than `random.choices` +lines.shuffle(seed=42) # or shuffle all lines in place and shard with slices +# WIP: lines.sort() # explodes to 16 bytes per line overhead for any length text +# WIP: sorted_order: tuple = lines.argsort() # similar to `numpy.argsort` ``` -Or you can batch and shard the data using slicing in any order with any steps: +Working on [RedPajama][redpajama], addressing 20 Billion annotated english documents, one will need only 160 GB of RAM instead of Terabytes. +Once loaded, the data will be memory-mapped, and can be reused between multiple Python processes without copies. +And of course, you can use slices to navigate the dataset and shard it between multiple workers. ```python lines[::3] # every third line @@ -380,6 +388,30 @@ lines[:-100:-1] # last 100 lines in reverse order [redpajama]: https://github.com/togethercomputer/RedPajama-Data +### Iterators and Memory Efficiency + +Python's operations like `split()` and `readlines()` immediately materialize a `list` of copied parts. +This can be very memory-inefficient for large datasets. +StringZilla saves a lot of memory by viewing existing memory regions as substrings, but even more memory can be saved by using lazily evaluated iterators. + +- `text.split_iter(separator=' ', keepseparator=False) -> SplitIterator[Str]` +- `text.rsplit_iter(separator=' ', keepseparator=False) -> SplitIterator[Str]` +- `text.split_charset_iter(separator='chars', keepseparator=False) -> SplitIterator[Str]` +- `text.rsplit_charset_iter(separator='chars', keepseparator=False) -> SplitIterator[Str]` + +StringZilla can easily be 10x more memory efficient than native Python classes for tokenization. 
+With lazy operations, it practically becomes free. + +```py +import stringzilla as sz +%load_ext memory_profiler + +text = open("enwik9.txt", "r").read() # 1 GB, mean word length 7.73 bytes +%memit text.split() # increment: 8670.12 MiB (152 ms) +%memit sz.split(text) # increment: 530.75 MiB (25 ms) +%memit sum(1 for _ in sz.split_iter(text)) # increment: 0.00 MiB +``` + ### Low-Level Python API Aside from calling the methods on the `Str` and `Strs` classes, you can also call the global functions directly on `str` and `bytes` instances. diff --git a/python/lib.c b/python/lib.c index 6b556b60..ac11741d 100644 --- a/python/lib.c +++ b/python/lib.c @@ -47,6 +47,7 @@ typedef SSIZE_T ssize_t; static PyTypeObject FileType; static PyTypeObject StrType; static PyTypeObject StrsType; +static PyTypeObject SplitIteratorType; static sz_string_view_t temporary_memory = {NULL, 0}; @@ -55,12 +56,12 @@ static sz_string_view_t temporary_memory = {NULL, 0}; * native `mmap` module, as it exposes the address of the mapping in memory. */ typedef struct { - PyObject_HEAD + PyObject_HEAD; #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) - HANDLE file_handle; + HANDLE file_handle; HANDLE mapping_handle; #else - int file_descriptor; + int file_descriptor; #endif sz_cptr_t start; sz_size_t length; @@ -79,25 +80,59 @@ typedef struct { * - Str(File("some-path.txt"), from=0, to=sys.maxint) */ typedef struct { - PyObject_HEAD // - PyObject *parent; + PyObject_HEAD; + PyObject *parent; sz_cptr_t start; sz_size_t length; } Str; +/** + * @brief String-splitting separator. + * + * Allows lazy evaluation of the `split` and `rsplit`, and can be used to create a `Strs` object. + * which might be more memory-friendly, than greedily invoking `str.split`. 
+ */ +typedef struct { + PyObject_HEAD; + + PyObject *text_object; //< For reference counting + PyObject *separator_object; //< For reference counting + + sz_string_view_t text; + sz_string_view_t separator; + sz_find_t finder; + + /// @brief How many bytes to skip after each successful find. + /// Generally equal to `needle_length`, or 1 for character sets. + sz_size_t match_length; + + /// @brief Should we include the separator in the resulting slices? + sz_bool_t include_match; + + /// @brief Should we enumerate the slices in normal or reverse order? + sz_bool_t is_reverse; + + /// @brief Upper limit for the number of splits to report. Monotonically decreases during iteration. + sz_size_t max_parts; + + /// @brief Indicates that we've already reported the tail of the split, and should return NULL next. + sz_bool_t reached_tail; + +} SplitIterator; + /** * @brief Variable length Python object similar to `Tuple[Union[Str, str]]`, * for faster sorting, shuffling, joins, and lookups. */ typedef struct { - PyObject_HEAD + PyObject_HEAD; - enum { - STRS_CONSECUTIVE_32, - STRS_CONSECUTIVE_64, - STRS_REORDERED, - STRS_MULTI_SOURCE, - } type; + enum { + STRS_CONSECUTIVE_32, + STRS_CONSECUTIVE_64, + STRS_REORDERED, + STRS_MULTI_SOURCE, + } type; union { /** @@ -254,18 +289,20 @@ typedef void (*get_string_at_offset_t)(Strs *, Py_ssize_t, Py_ssize_t, PyObject void str_at_offset_consecutive_32bit(Strs *strs, Py_ssize_t i, Py_ssize_t count, // PyObject **parent_string, char const **start, size_t *length) { uint32_t start_offset = (i == 0) ? 
0 : strs->data.consecutive_32bit.end_offsets[i - 1]; - uint32_t end_offset = strs->data.consecutive_32bit.end_offsets[i]; + uint32_t end_offset = strs->data.consecutive_32bit.end_offsets[i] - // + strs->data.consecutive_32bit.separator_length * (i + 1 != count); *start = strs->data.consecutive_32bit.start + start_offset; - *length = end_offset - start_offset - strs->data.consecutive_32bit.separator_length * (i + 1 != count); + *length = end_offset - start_offset; *parent_string = strs->data.consecutive_32bit.parent_string; } void str_at_offset_consecutive_64bit(Strs *strs, Py_ssize_t i, Py_ssize_t count, // PyObject **parent_string, char const **start, size_t *length) { uint64_t start_offset = (i == 0) ? 0 : strs->data.consecutive_64bit.end_offsets[i - 1]; - uint64_t end_offset = strs->data.consecutive_64bit.end_offsets[i]; + uint64_t end_offset = strs->data.consecutive_64bit.end_offsets[i] - // + strs->data.consecutive_64bit.separator_length * (i + 1 != count); *start = strs->data.consecutive_64bit.start + start_offset; - *length = end_offset - start_offset - strs->data.consecutive_64bit.separator_length * (i + 1 != count); + *length = end_offset - start_offset; *parent_string = strs->data.consecutive_64bit.parent_string; } @@ -348,7 +385,7 @@ sz_bool_t prepare_strings_for_extension(Strs *strs, size_t new_parents, size_t n #pragma endregion -#pragma region MemoryMappingFile +#pragma region Memory Mapping File static void File_dealloc(File *self) { #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) @@ -577,6 +614,20 @@ static void Str_dealloc(Str *self) { static PyObject *Str_str(Str *self) { return PyUnicode_FromStringAndSize(self->start, self->length); } +static PyObject *Str_repr(Str *self) { + // Interestingly, known-length string formatting only works in Python 3.12 and later. 
+ // https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_FromFormat + if (PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 12) + return PyUnicode_FromFormat("sz.Str('%.*s')", (int)self->length, self->start); + else { + // Use a simpler formatting rule for older versions + PyObject *str_obj = PyUnicode_FromStringAndSize(self->start, self->length); + PyObject *result = PyUnicode_FromFormat("sz.Str('%U')", str_obj); + Py_DECREF(str_obj); + return result; + } +} + static Py_hash_t Str_hash(Str *self) { return (Py_hash_t)sz_hash(self->start, self->length); } static PyObject *Str_like_hash(PyObject *self, PyObject *args, PyObject *kwargs) { @@ -588,11 +639,11 @@ static PyObject *Str_like_hash(PyObject *self, PyObject *args, PyObject *kwargs) return NULL; } - PyObject *text_obj = is_member ? self : PyTuple_GET_ITEM(args, 0); + PyObject *text_object = is_member ? self : PyTuple_GET_ITEM(args, 0); sz_string_view_t text; // Validate and convert `text` - if (!export_string_like(text_obj, &text.start, &text.length)) { + if (!export_string_like(text_object, &text.start, &text.length)) { PyErr_SetString(PyExc_TypeError, "The text argument must be string-like"); return NULL; } @@ -680,15 +731,20 @@ static void Str_releasebuffer(PyObject *_, Py_buffer *view) { // https://docs.python.org/3/c-api/typeobj.html#c.PyBufferProcs.bf_releasebuffer } -static int Str_in(Str *self, PyObject *arg) { +/** + * @brief Will be called by the `PySequence_Contains` to check presence of a substring. + * @return 1 if the string is present, 0 if it is not, -1 in case of error. 
+ * @see Docs: https://docs.python.org/3/c-api/sequence.html#c.PySequence_Contains + */ +static int Str_in(Str *self, PyObject *needle_obj) { - sz_string_view_t needle_struct; - if (!export_string_like(arg, &needle_struct.start, &needle_struct.length)) { + sz_string_view_t needle; + if (!export_string_like(needle_obj, &needle.start, &needle.length)) { PyErr_SetString(PyExc_TypeError, "Unsupported argument type"); return -1; } - return sz_find(self->start, self->length, needle_struct.start, needle_struct.length) != NULL; + return sz_find(self->start, self->length, needle.start, needle.length) != NULL; } static PyObject *Strs_get_tape(Str *self, void *closure) { return NULL; } @@ -887,14 +943,18 @@ static PyObject *Strs_subscript(Strs *self, PyObject *key) { return (PyObject *)result; } -// Will be called by the `PySequence_Contains` -static int Strs_contains(Str *self, PyObject *needle_obj) { +/** + * @brief Will be called by the `PySequence_Contains` to check the presence of a string in array. + * @return 1 if the string is present, 0 if it is not, -1 in case of error. + * @see Docs: https://docs.python.org/3/c-api/sequence.html#c.PySequence_Contains + */ +static int Strs_in(Str *self, PyObject *needle_obj) { // Validate and convert `needle` sz_string_view_t needle; if (!export_string_like(needle_obj, &needle.start, &needle.length)) { PyErr_SetString(PyExc_TypeError, "The needle argument must be string-like"); - return NULL; + return -1; } // Depending on the layout, we will need to use different logic @@ -902,7 +962,7 @@ static int Strs_contains(Str *self, PyObject *needle_obj) { get_string_at_offset_t getter = str_at_offset_getter(self); if (!getter) { PyErr_SetString(PyExc_TypeError, "Unknown Strs kind"); - return NULL; + return -1; } // Time for a full-scan @@ -1112,7 +1172,7 @@ static PyObject *Str_write_to(PyObject *self, PyObject *args, PyObject *kwargs) return NULL; } - PyObject *text_obj = is_member ? 
self : PyTuple_GET_ITEM(args, 0); + PyObject *text_object = is_member ? self : PyTuple_GET_ITEM(args, 0); PyObject *path_obj = PyTuple_GET_ITEM(args, !is_member + 0); // Parse keyword arguments @@ -1125,7 +1185,7 @@ static PyObject *Str_write_to(PyObject *self, PyObject *args, PyObject *kwargs) sz_string_view_t path; // Validate and convert `text` and `path` - if (!export_string_like(text_obj, &text.start, &text.length) || + if (!export_string_like(text_object, &text.start, &text.length) || !export_string_like(path_obj, &path.start, &path.length)) { PyErr_SetString(PyExc_TypeError, "Text and path must be string-like"); return NULL; @@ -1181,7 +1241,7 @@ static PyObject *Str_offset_within(PyObject *self, PyObject *args, PyObject *kwa } PyObject *slice_obj = is_member ? self : PyTuple_GET_ITEM(args, 0); - PyObject *text_obj = PyTuple_GET_ITEM(args, !is_member + 0); + PyObject *text_object = PyTuple_GET_ITEM(args, !is_member + 0); // Parse keyword arguments if (kwargs) { @@ -1193,7 +1253,7 @@ static PyObject *Str_offset_within(PyObject *self, PyObject *args, PyObject *kwa sz_string_view_t slice; // Validate and convert `text` and `slice` - if (!export_string_like(text_obj, &text.start, &text.length) || + if (!export_string_like(text_object, &text.start, &text.length) || !export_string_like(slice_obj, &slice.start, &slice.length)) { PyErr_SetString(PyExc_TypeError, "Text and slice must be string-like"); return NULL; @@ -1212,7 +1272,7 @@ static PyObject *Str_offset_within(PyObject *self, PyObject *args, PyObject *kwa * @return 1 on success, 0 on failure. 
*/ static int _Str_find_implementation_( // - PyObject *self, PyObject *args, PyObject *kwargs, sz_find_t finder, Py_ssize_t *offset_out, + PyObject *self, PyObject *args, PyObject *kwargs, sz_find_t finder, sz_bool_t is_reverse, Py_ssize_t *offset_out, sz_string_view_t *haystack_out, sz_string_view_t *needle_out) { int is_member = self != NULL && PyObject_TypeCheck(self, &StrType); @@ -1278,6 +1338,14 @@ static int _Str_find_implementation_( // haystack.start += normalized_offset; haystack.length = normalized_length; + // If the needle length is zero, the result is start index in normal order or end index in reverse order + if (needle.length == 0) { + *offset_out = !is_reverse ? normalized_offset : (normalized_offset + normalized_length); + *haystack_out = haystack; + *needle_out = needle; + return 1; + } + // Perform contains operation sz_cptr_t match = finder(haystack.start, haystack.length, needle.start, needle.length); if (match == NULL) { *offset_out = -1; } @@ -1292,7 +1360,8 @@ static PyObject *Str_contains(PyObject *self, PyObject *args, PyObject *kwargs) Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_find, &signed_offset, &text, &separator)) return NULL; + if (!_Str_find_implementation_(self, args, kwargs, &sz_find, sz_false_k, &signed_offset, &text, &separator)) + return NULL; if (signed_offset == -1) { Py_RETURN_FALSE; } else { Py_RETURN_TRUE; } } @@ -1301,7 +1370,8 @@ static PyObject *Str_find(PyObject *self, PyObject *args, PyObject *kwargs) { Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_find, &signed_offset, &text, &separator)) return NULL; + if (!_Str_find_implementation_(self, args, kwargs, &sz_find, sz_false_k, &signed_offset, &text, &separator)) + return NULL; return PyLong_FromSsize_t(signed_offset); } @@ -1309,7 +1379,8 @@ static PyObject *Str_index(PyObject *self, 
PyObject *args, PyObject *kwargs) { Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_find, &signed_offset, &text, &separator)) return NULL; + if (!_Str_find_implementation_(self, args, kwargs, &sz_find, sz_false_k, &signed_offset, &text, &separator)) + return NULL; if (signed_offset == -1) { PyErr_SetString(PyExc_ValueError, "substring not found"); return NULL; @@ -1321,7 +1392,8 @@ static PyObject *Str_rfind(PyObject *self, PyObject *args, PyObject *kwargs) { Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind, &signed_offset, &text, &separator)) return NULL; + if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind, sz_true_k, &signed_offset, &text, &separator)) + return NULL; return PyLong_FromSsize_t(signed_offset); } @@ -1329,7 +1401,8 @@ static PyObject *Str_rindex(PyObject *self, PyObject *args, PyObject *kwargs) { Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind, &signed_offset, &text, &separator)) return NULL; + if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind, sz_true_k, &signed_offset, &text, &separator)) + return NULL; if (signed_offset == -1) { PyErr_SetString(PyExc_ValueError, "substring not found"); return NULL; @@ -1337,14 +1410,22 @@ static PyObject *Str_rindex(PyObject *self, PyObject *args, PyObject *kwargs) { return PyLong_FromSsize_t(signed_offset); } -static PyObject *_Str_partition_implementation(PyObject *self, PyObject *args, PyObject *kwargs, sz_find_t finder) { +static PyObject *_Str_partition_implementation(PyObject *self, PyObject *args, PyObject *kwargs, sz_find_t finder, + sz_bool_t is_reverse) { Py_ssize_t separator_index; sz_string_view_t text; sz_string_view_t separator; PyObject *result_tuple; // Use _Str_find_implementation_ to get the index of the 
separator - if (!_Str_find_implementation_(self, args, kwargs, finder, &separator_index, &text, &separator)) return NULL; + if (!_Str_find_implementation_(self, args, kwargs, finder, is_reverse, &separator_index, &text, &separator)) + return NULL; + + // If the separator length is zero, we must raise a `ValueError` + if (separator.length == 0) { + PyErr_SetString(PyExc_ValueError, "empty separator"); + return NULL; + } // If separator is not found, return a tuple (self, "", "") if (separator_index == -1) { @@ -1384,11 +1465,11 @@ static PyObject *_Str_partition_implementation(PyObject *self, PyObject *args, P } static PyObject *Str_partition(PyObject *self, PyObject *args, PyObject *kwargs) { - return _Str_partition_implementation(self, args, kwargs, &sz_find); + return _Str_partition_implementation(self, args, kwargs, &sz_find, sz_false_k); } static PyObject *Str_rpartition(PyObject *self, PyObject *args, PyObject *kwargs) { - return _Str_partition_implementation(self, args, kwargs, &sz_rfind); + return _Str_partition_implementation(self, args, kwargs, &sz_rfind, sz_true_k); } static PyObject *Str_count(PyObject *self, PyObject *args, PyObject *kwargs) { @@ -1767,7 +1848,8 @@ static PyObject *Str_find_first_of(PyObject *self, PyObject *args, PyObject *kwa Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_find_char_from, &signed_offset, &text, &separator)) + if (!_Str_find_implementation_(self, args, kwargs, &sz_find_char_from, sz_false_k, &signed_offset, &text, + &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); } @@ -1776,7 +1858,8 @@ static PyObject *Str_find_first_not_of(PyObject *self, PyObject *args, PyObject Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_find_char_not_from, &signed_offset, &text, &separator)) + if (!_Str_find_implementation_(self, args, kwargs, 
&sz_find_char_not_from, sz_false_k, &signed_offset, &text, + &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); } @@ -1785,7 +1868,8 @@ static PyObject *Str_find_last_of(PyObject *self, PyObject *args, PyObject *kwar Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_char_from, &signed_offset, &text, &separator)) + if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_char_from, sz_true_k, &signed_offset, &text, + &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); } @@ -1794,13 +1878,49 @@ static PyObject *Str_find_last_not_of(PyObject *self, PyObject *args, PyObject * Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_char_not_from, &signed_offset, &text, &separator)) + if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_char_not_from, sz_true_k, &signed_offset, &text, + &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); } -static Strs *Str_split_(PyObject *parent_string, sz_string_view_t text, sz_string_view_t separator, int keepseparator, - Py_ssize_t maxsplit) { +/** + * @brief Given parsed split settings, constructs an iterator that would produce that split. 
+ */ +static SplitIterator *Str_split_iter_(PyObject *text_object, PyObject *separator_object, // + sz_string_view_t const text, sz_string_view_t const separator, // + int keepseparator, Py_ssize_t maxsplit, sz_find_t finder, sz_size_t match_length, + sz_bool_t is_reverse) { + + // Create a new `SplitIterator` object + SplitIterator *result_obj = (SplitIterator *)SplitIteratorType.tp_alloc(&SplitIteratorType, 0); + if (result_obj == NULL && PyErr_NoMemory()) return NULL; + + // Set its properties based on the slice + result_obj->text_object = text_object; + result_obj->separator_object = separator_object; + result_obj->text = text; + result_obj->separator = separator; + result_obj->finder = finder; + + result_obj->match_length = match_length; + result_obj->include_match = keepseparator; + result_obj->is_reverse = is_reverse; + result_obj->max_parts = (sz_size_t)maxsplit + 1; + result_obj->reached_tail = 0; + + // Increment the reference count of the parent + Py_INCREF(result_obj->text_object); + Py_XINCREF(result_obj->separator_object); + return result_obj; +} + +/** + * @brief Implements the normal order split logic for both string-delimiters and character sets. + * Produuces one of the consecutive layouts - `STRS_CONSECUTIVE_64` or `STRS_CONSECUTIVE_32`. 
+ */ +static Strs *Str_split_(PyObject *parent_string, sz_string_view_t const text, sz_string_view_t const separator, + int keepseparator, Py_ssize_t maxsplit, sz_find_t finder, sz_size_t match_length) { // Create Strs object Strs *result = (Strs *)PyObject_New(Strs, &StrsType); if (!result) return NULL; @@ -1815,21 +1935,36 @@ static Strs *Str_split_(PyObject *parent_string, sz_string_view_t text, sz_strin result->type = STRS_CONSECUTIVE_64; result->data.consecutive_64bit.start = text.start; result->data.consecutive_64bit.parent_string = parent_string; - result->data.consecutive_64bit.separator_length = !keepseparator * separator.length; + result->data.consecutive_64bit.separator_length = !keepseparator * match_length; } else { bytes_per_offset = 4; result->type = STRS_CONSECUTIVE_32; result->data.consecutive_32bit.start = text.start; result->data.consecutive_32bit.parent_string = parent_string; - result->data.consecutive_32bit.separator_length = !keepseparator * separator.length; + result->data.consecutive_32bit.separator_length = !keepseparator * match_length; } - // Iterate through string, keeping track of the - sz_size_t last_start = 0; - while (last_start <= text.length && offsets_count < maxsplit) { - sz_cptr_t match = sz_find(text.start + last_start, text.length - last_start, separator.start, separator.length); - sz_size_t offset_in_remaining = match ? match - text.start - last_start : text.length - last_start; + sz_bool_t reached_tail = 0; + sz_size_t total_skipped = 0; + sz_size_t max_parts = (sz_size_t)maxsplit + 1; + while (!reached_tail) { + + sz_cptr_t match = + offsets_count + 1 < max_parts + ? 
finder(text.start + total_skipped, text.length - total_skipped, separator.start, separator.length) + : NULL; + + sz_size_t part_end_offset; + if (match) { + part_end_offset = (match - text.start) + match_length; + total_skipped = part_end_offset; + } + else { + part_end_offset = text.length; + total_skipped = text.length; + reached_tail = 1; + } // Reallocate offsets array if needed if (offsets_count >= offsets_capacity) { @@ -1849,17 +1984,13 @@ static Strs *Str_split_(PyObject *parent_string, sz_string_view_t text, sz_strin } // Export the offset - size_t will_continue = match != NULL; - size_t next_offset = last_start + offset_in_remaining + separator.length * will_continue; - if (text.length >= UINT32_MAX) { ((uint64_t *)offsets_endings)[offsets_count++] = (uint64_t)next_offset; } - else { ((uint32_t *)offsets_endings)[offsets_count++] = (uint32_t)next_offset; } - - // Next time we want to start - last_start = last_start + offset_in_remaining + separator.length; + if (bytes_per_offset == 8) { ((uint64_t *)offsets_endings)[offsets_count] = (uint64_t)part_end_offset; } + else { ((uint32_t *)offsets_endings)[offsets_count] = (uint32_t)part_end_offset; } + offsets_count++; } // Populate the Strs object with the offsets - if (text.length >= UINT32_MAX) { + if (bytes_per_offset == 8) { result->data.consecutive_64bit.end_offsets = offsets_endings; result->data.consecutive_64bit.count = offsets_count; } @@ -1872,7 +2003,92 @@ static Strs *Str_split_(PyObject *parent_string, sz_string_view_t text, sz_strin return result; } -static PyObject *Str_split(PyObject *self, PyObject *args, PyObject *kwargs) { +/** + * @brief Implements the reverse order split logic for both string-delimiters and character sets. + * Unlike the `Str_split_` can't use consecutive layouts and produces a `REAORDERED` one. 
+ */ +static Strs *Str_rsplit_(PyObject *parent_string, sz_string_view_t const text, sz_string_view_t const separator, + int keepseparator, Py_ssize_t maxsplit, sz_find_t finder, sz_size_t match_length) { + // Create Strs object + Strs *result = (Strs *)PyObject_New(Strs, &StrsType); + if (!result) return NULL; + + // Initialize Strs object based on the splitting logic + result->type = STRS_REORDERED; + result->data.reordered.parent_string = parent_string; + result->data.reordered.parts = NULL; + result->data.reordered.count = 0; + + // Keep track of the memory usage + sz_string_view_t *parts = NULL; + sz_size_t parts_capacity = 0; + sz_size_t parts_count = 0; + + sz_bool_t reached_tail = 0; + sz_size_t total_skipped = 0; + sz_size_t max_parts = (sz_size_t)maxsplit + 1; + while (!reached_tail) { + + sz_cptr_t match = parts_count + 1 < max_parts + ? finder(text.start, text.length - total_skipped, separator.start, separator.length) + : NULL; + + // Determine the next part + sz_string_view_t part; + if (match) { + part.start = match + match_length * !keepseparator; + part.length = text.start + text.length - total_skipped - part.start; + total_skipped = text.start + text.length - match; + } + else { + part.start = text.start; + part.length = text.length - total_skipped; + reached_tail = 1; + } + + // Reallocate parts array if needed + if (parts_count >= parts_capacity) { + parts_capacity = (parts_capacity + 1) * 2; + sz_string_view_t *new_parts = (sz_string_view_t *)realloc(parts, parts_capacity * sizeof(sz_string_view_t)); + if (!new_parts) { + if (parts) free(parts); + } + parts = new_parts; + } + + // If the memory allocation has failed - discard the response + if (!parts) { + Py_XDECREF(result); + PyErr_NoMemory(); + return NULL; + } + + // Populate the parts array + parts[parts_count] = part; + parts_count++; + } + + // Python does this weird thing, where the `rsplit` results appear in the same order as `split` + // so we need to reverse the order of elements in 
the `parts` array. + for (sz_size_t i = 0; i < parts_count / 2; i++) { + sz_string_view_t temp = parts[i]; + parts[i] = parts[parts_count - i - 1]; + parts[parts_count - i - 1] = temp; + } + + result->data.reordered.parts = parts; + result->data.reordered.count = parts_count; + Py_INCREF(parent_string); + return result; +} + +/** + * @brief Proxy routing requests like `Str.split`, `Str.rsplit`, `Str.split_charset` and `Str.rsplit_charset` + * to `Str_split_` and `Str_rsplit_` implementations, parsing function arguments. + */ +static PyObject *Str_split_with_known_callback(PyObject *self, PyObject *args, PyObject *kwargs, // + sz_find_t finder, sz_size_t match_length, // + sz_bool_t is_reverse, sz_bool_t is_lazy_iterator) { // Check minimum arguments int is_member = self != NULL && PyObject_TypeCheck(self, &StrType); Py_ssize_t nargs = PyTuple_Size(args); @@ -1881,8 +2097,8 @@ static PyObject *Str_split(PyObject *self, PyObject *args, PyObject *kwargs) { return NULL; } - PyObject *text_obj = is_member ? self : PyTuple_GET_ITEM(args, 0); - PyObject *separator_obj = nargs > !is_member + 0 ? PyTuple_GET_ITEM(args, !is_member + 0) : NULL; + PyObject *text_object = is_member ? self : PyTuple_GET_ITEM(args, 0); + PyObject *separator_object = nargs > !is_member + 0 ? PyTuple_GET_ITEM(args, !is_member + 0) : NULL; PyObject *maxsplit_obj = nargs > !is_member + 1 ? PyTuple_GET_ITEM(args, !is_member + 1) : NULL; PyObject *keepseparator_obj = nargs > !is_member + 2 ? 
PyTuple_GET_ITEM(args, !is_member + 2) : NULL; @@ -1890,7 +2106,7 @@ static PyObject *Str_split(PyObject *self, PyObject *args, PyObject *kwargs) { PyObject *key, *value; Py_ssize_t pos = 0; while (PyDict_Next(kwargs, &pos, &key, &value)) { - if (PyUnicode_CompareWithASCIIString(key, "separator") == 0) { separator_obj = value; } + if (PyUnicode_CompareWithASCIIString(key, "separator") == 0) { separator_object = value; } else if (PyUnicode_CompareWithASCIIString(key, "maxsplit") == 0) { maxsplit_obj = value; } else if (PyUnicode_CompareWithASCIIString(key, "keepseparator") == 0) { keepseparator_obj = value; } else if (PyErr_Format(PyExc_TypeError, "Got an unexpected keyword argument '%U'", key)) @@ -1904,21 +2120,27 @@ static PyObject *Str_split(PyObject *self, PyObject *args, PyObject *kwargs) { Py_ssize_t maxsplit; // Validate and convert `text` - if (!export_string_like(text_obj, &text.start, &text.length)) { + if (!export_string_like(text_object, &text.start, &text.length)) { PyErr_SetString(PyExc_TypeError, "The text argument must be string-like"); return NULL; } // Validate and convert `separator` - if (separator_obj) { - if (!export_string_like(separator_obj, &separator.start, &separator.length)) { + if (separator_object) { + if (!export_string_like(separator_object, &separator.start, &separator.length)) { PyErr_SetString(PyExc_TypeError, "The separator argument must be string-like"); return NULL; } + // Raise a `ValueError` if it's length is zero, like the native `str.split` + if (separator.length == 0) { + PyErr_SetString(PyExc_ValueError, "The separator argument must not be empty"); + return NULL; + } + if (match_length == 0) match_length = separator.length; } else { separator.start = " "; - separator.length = 1; + match_length = separator.length = 1; } // Validate and convert `keepseparator` @@ -1941,7 +2163,45 @@ static PyObject *Str_split(PyObject *self, PyObject *args, PyObject *kwargs) { } else { maxsplit = PY_SSIZE_T_MAX; } - return 
Str_split_(text_obj, text, separator, keepseparator, maxsplit); + // Dispatch the right backend + if (is_lazy_iterator) + return Str_split_iter_(text_object, separator_object, text, separator, // + keepseparator, maxsplit, finder, match_length, is_reverse); + else + return !is_reverse ? Str_split_(text_object, text, separator, keepseparator, maxsplit, finder, match_length) + : Str_rsplit_(text_object, text, separator, keepseparator, maxsplit, finder, match_length); +} + +static PyObject *Str_split(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_find, 0, sz_false_k, sz_false_k); +} + +static PyObject *Str_rsplit(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_rfind, 0, sz_true_k, sz_false_k); +} + +static PyObject *Str_split_charset(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_find_char_from, 1, sz_false_k, sz_false_k); +} + +static PyObject *Str_rsplit_charset(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_rfind_char_from, 1, sz_true_k, sz_false_k); +} + +static PyObject *Str_split_iter(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_find, 0, sz_false_k, sz_true_k); +} + +static PyObject *Str_rsplit_iter(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_rfind, 0, sz_true_k, sz_true_k); +} + +static PyObject *Str_split_charset_iter(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_find_char_from, 1, sz_false_k, sz_true_k); +} + +static PyObject *Str_rsplit_charset_iter(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_rfind_char_from, 1, 
sz_true_k, sz_true_k); } static PyObject *Str_splitlines(PyObject *self, PyObject *args, PyObject *kwargs) { @@ -1953,7 +2213,7 @@ static PyObject *Str_splitlines(PyObject *self, PyObject *args, PyObject *kwargs return NULL; } - PyObject *text_obj = is_member ? self : PyTuple_GET_ITEM(args, 0); + PyObject *text_object = is_member ? self : PyTuple_GET_ITEM(args, 0); PyObject *keeplinebreaks_obj = nargs > !is_member ? PyTuple_GET_ITEM(args, !is_member) : NULL; PyObject *maxsplit_obj = nargs > !is_member + 1 ? PyTuple_GET_ITEM(args, !is_member + 1) : NULL; @@ -1972,7 +2232,7 @@ static PyObject *Str_splitlines(PyObject *self, PyObject *args, PyObject *kwargs Py_ssize_t maxsplit = PY_SSIZE_T_MAX; // Default value for maxsplit // Validate and convert `text` - if (!export_string_like(text_obj, &text.start, &text.length)) { + if (!export_string_like(text_object, &text.start, &text.length)) { PyErr_SetString(PyExc_TypeError, "The text argument must be string-like"); return NULL; } @@ -2003,7 +2263,7 @@ static PyObject *Str_splitlines(PyObject *self, PyObject *args, PyObject *kwargs sz_string_view_t separator; separator.start = "\n"; separator.length = 1; - return Str_split_(text_obj, text, separator, keeplinebreaks, maxsplit); + return Str_split_(text_object, text, separator, keeplinebreaks, maxsplit, &sz_find, 1); } static PyObject *Str_concat(PyObject *self, PyObject *other) { @@ -2079,15 +2339,16 @@ static PyMethodDef Str_methods[] = { {"splitlines", Str_splitlines, SZ_METHOD_FLAGS, "Split a string by line breaks."}, {"startswith", Str_startswith, SZ_METHOD_FLAGS, "Check if a string starts with a given prefix."}, {"endswith", Str_endswith, SZ_METHOD_FLAGS, "Check if a string ends with a given suffix."}, - {"split", Str_split, SZ_METHOD_FLAGS, "Split a string by a separator."}, // Bidirectional operations {"find", Str_find, SZ_METHOD_FLAGS, "Find the first occurrence of a substring."}, {"index", Str_index, SZ_METHOD_FLAGS, "Find the first occurrence of a substring or 
raise error if missing."}, {"partition", Str_partition, SZ_METHOD_FLAGS, "Splits string into 3-tuple: before, first match, after."}, + {"split", Str_split, SZ_METHOD_FLAGS, "Split a string by a separator."}, {"rfind", Str_rfind, SZ_METHOD_FLAGS, "Find the last occurrence of a substring."}, {"rindex", Str_rindex, SZ_METHOD_FLAGS, "Find the last occurrence of a substring or raise error if missing."}, {"rpartition", Str_rpartition, SZ_METHOD_FLAGS, "Splits string into 3-tuple: before, last match, after."}, + {"rsplit", Str_rsplit, SZ_METHOD_FLAGS, "Split a string by a separator in reverse order."}, // Edit distance extensions {"hamming_distance", Str_hamming_distance, SZ_METHOD_FLAGS, @@ -2110,6 +2371,18 @@ static PyMethodDef Str_methods[] = { "Finds the first occurrence of a character not present in another string."}, {"find_last_not_of", Str_find_last_not_of, SZ_METHOD_FLAGS, "Finds the last occurrence of a character not present in another string."}, + {"split_charset", Str_split_charset, SZ_METHOD_FLAGS, "Split a string by a set of character separators."}, + {"rsplit_charset", Str_rsplit_charset, SZ_METHOD_FLAGS, + "Split a string by a set of character separators in reverse order."}, + + // Lazily evaluated iterators + {"split_iter", Str_split_iter, SZ_METHOD_FLAGS, "Create an iterator for splitting a string by a separator."}, + {"rsplit_iter", Str_rsplit_iter, SZ_METHOD_FLAGS, + "Create an iterator for splitting a string by a separator in reverse order."}, + {"split_charset_iter", Str_split_charset_iter, SZ_METHOD_FLAGS, + "Create an iterator for splitting a string by a set of character separators."}, + {"rsplit_charset_iter", Str_rsplit_charset_iter, SZ_METHOD_FLAGS, + "Create an iterator for splitting a string by a set of character separators in reverse order."}, // Dealing with larger-than-memory datasets {"offset_within", Str_offset_within, SZ_METHOD_FLAGS, @@ -2128,6 +2401,7 @@ static PyTypeObject StrType = { .tp_dealloc = Str_dealloc, .tp_hash = Str_hash, 
.tp_richcompare = Str_richcompare, + .tp_repr = (reprfunc)Str_repr, .tp_str = Str_str, .tp_methods = Str_methods, .tp_as_sequence = &Str_as_sequence, @@ -2139,7 +2413,84 @@ static PyTypeObject StrType = { #pragma endregion -#pragma regions Strs +#pragma region Split Iterator + +static PyObject *SplitIteratorType_next(SplitIterator *self) { + // No more data to split + if (self->reached_tail) return NULL; + + // Create a new `Str` object + Str *result_obj = (Str *)StrType.tp_alloc(&StrType, 0); + if (result_obj == NULL && PyErr_NoMemory()) return NULL; + + sz_cptr_t result_start; + sz_size_t result_length; + + // Find the next needle + sz_cptr_t found = + self->max_parts > 1 // + ? self->finder(self->text.start, self->text.length, self->separator.start, self->separator.length) + : NULL; + + // We've reached the end of the string + if (found == NULL) { + result_start = self->text.start; + result_length = self->text.length; + self->text.length = 0; + self->reached_tail = 1; + self->max_parts = 0; + } + else { + if (self->is_reverse) { + result_start = found + self->match_length * !self->include_match; + result_length = self->text.start + self->text.length - result_start; + self->text.length = found - self->text.start; + } + else { + result_start = self->text.start; + result_length = found - self->text.start; + result_length += self->match_length * self->include_match; + self->text.start = found + self->match_length; + self->text.length -= result_length + self->match_length; + } + self->max_parts--; + } + + // Set its properties based on the slice + result_obj->start = result_start; + result_obj->length = result_length; + result_obj->parent = self->text_object; + + // Increment the reference count of the parent + Py_INCREF(self->text_object); + return (PyObject *)result_obj; +} + +static void SplitIteratorType_dealloc(SplitIterator *self) { + Py_XDECREF(self->text_object); + Py_XDECREF(self->separator_object); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static 
PyObject *SplitIteratorType_iter(PyObject *self) { + Py_INCREF(self); // Iterator should return itself in __iter__. + return self; +} + +static PyTypeObject SplitIteratorType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "stringzilla.SplitIterator", + .tp_basicsize = sizeof(SplitIterator), + .tp_itemsize = 0, + .tp_dealloc = (destructor)SplitIteratorType_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Text-splitting iterator", + .tp_iter = SplitIteratorType_iter, + .tp_iternext = (iternextfunc)SplitIteratorType_next, +}; + +#pragma endregion + +#pragma region Strs static PyObject *Strs_shuffle(Strs *self, PyObject *args, PyObject *kwargs) { unsigned int seed = time(NULL); // Default seed @@ -2464,10 +2815,25 @@ static PyObject *Strs_sample(Strs *self, PyObject *args, PyObject *kwargs) { return result; } +/** + * @brief Array to string conversion method, that concatenates all the strings in the array. + */ +static PyObject *Strs_str(Strs *self) { + // This is just an example, adapt it to your needs + // For instance, you could iterate over your Strs and concatenate them into a single string + return PyUnicode_FromFormat("<%s object at %p>", Py_TYPE(self)->tp_name, self); +} + +static PyObject *Strs_repr(Strs *self) { + // This is just an example, adapt it to your needs + // For instance, you could iterate over your Strs and concatenate them into a single string + return PyUnicode_FromFormat("<%s object at %p>", Py_TYPE(self)->tp_name, self); +} + static PySequenceMethods Strs_as_sequence = { - .sq_length = Strs_len, // - .sq_item = Strs_getitem, // - .sq_contains = Strs_contains, // + .sq_length = Strs_len, // + .sq_item = Strs_getitem, // + .sq_contains = Strs_in, // }; static PyMappingMethods Strs_as_mapping = { @@ -2507,6 +2873,8 @@ static PyTypeObject StrsType = { .tp_as_mapping = &Strs_as_mapping, .tp_getset = Strs_getsetters, .tp_richcompare = Strs_richcompare, + .tp_repr = (reprfunc)Strs_repr, + .tp_str = (reprfunc)Strs_str, }; #pragma endregion @@ 
-2524,15 +2892,16 @@ static PyMethodDef stringzilla_methods[] = { {"splitlines", Str_splitlines, SZ_METHOD_FLAGS, "Split a string by line breaks."}, {"startswith", Str_startswith, SZ_METHOD_FLAGS, "Check if a string starts with a given prefix."}, {"endswith", Str_endswith, SZ_METHOD_FLAGS, "Check if a string ends with a given suffix."}, - {"split", Str_split, SZ_METHOD_FLAGS, "Split a string by a separator."}, // Bidirectional operations {"find", Str_find, SZ_METHOD_FLAGS, "Find the first occurrence of a substring."}, {"index", Str_index, SZ_METHOD_FLAGS, "Find the first occurrence of a substring or raise error if missing."}, {"partition", Str_partition, SZ_METHOD_FLAGS, "Splits string into 3-tuple: before, first match, after."}, + {"split", Str_split, SZ_METHOD_FLAGS, "Split a string by a separator."}, {"rfind", Str_rfind, SZ_METHOD_FLAGS, "Find the last occurrence of a substring."}, {"rindex", Str_rindex, SZ_METHOD_FLAGS, "Find the last occurrence of a substring or raise error if missing."}, {"rpartition", Str_rpartition, SZ_METHOD_FLAGS, "Splits string into 3-tuple: before, last match, after."}, + {"rsplit", Str_rsplit, SZ_METHOD_FLAGS, "Split a string by a separator in reverse order."}, // Edit distance extensions {"hamming_distance", Str_hamming_distance, SZ_METHOD_FLAGS, @@ -2555,6 +2924,23 @@ static PyMethodDef stringzilla_methods[] = { "Finds the first occurrence of a character not present in another string."}, {"find_last_not_of", Str_find_last_not_of, SZ_METHOD_FLAGS, "Finds the last occurrence of a character not present in another string."}, + {"split_charset", Str_split_charset, SZ_METHOD_FLAGS, "Split a string by a set of character separators."}, + {"rsplit_charset", Str_rsplit_charset, SZ_METHOD_FLAGS, + "Split a string by a set of character separators in reverse order."}, + + // Lazily evaluated iterators + {"split_iter", Str_split_iter, SZ_METHOD_FLAGS, "Create an iterator for splitting a string by a separator."}, + {"rsplit_iter", Str_rsplit_iter, 
SZ_METHOD_FLAGS, + "Create an iterator for splitting a string by a separator in reverse order."}, + {"split_charset_iter", Str_split_charset_iter, SZ_METHOD_FLAGS, + "Create an iterator for splitting a string by a set of character separators."}, + {"rsplit_charset_iter", Str_rsplit_charset_iter, SZ_METHOD_FLAGS, + "Create an iterator for splitting a string by a set of character separators in reverse order."}, + + // Dealing with larger-than-memory datasets + {"offset_within", Str_offset_within, SZ_METHOD_FLAGS, + "Return the raw byte offset of one binary string within another."}, + {"write_to", Str_write_to, SZ_METHOD_FLAGS, "Return the raw byte offset of one binary string within another."}, // Global unary extensions {"hash", Str_like_hash, SZ_METHOD_FLAGS, "Hash a string or a byte-array."}, @@ -2579,6 +2965,7 @@ PyMODINIT_FUNC PyInit_stringzilla(void) { if (PyType_Ready(&StrType) < 0) return NULL; if (PyType_Ready(&FileType) < 0) return NULL; if (PyType_Ready(&StrsType) < 0) return NULL; + if (PyType_Ready(&SplitIteratorType) < 0) return NULL; m = PyModule_Create(&stringzilla_module); if (m == NULL) return NULL; @@ -2632,6 +3019,16 @@ PyMODINIT_FUNC PyInit_stringzilla(void) { return NULL; } + Py_INCREF(&SplitIteratorType); + if (PyModule_AddObject(m, "SplitIterator", (PyObject *)&SplitIteratorType) < 0) { + Py_XDECREF(&SplitIteratorType); + Py_XDECREF(&StrsType); + Py_XDECREF(&FileType); + Py_XDECREF(&StrType); + Py_XDECREF(m); + return NULL; + } + // Initialize temporary_memory, if needed temporary_memory.start = malloc(4096); temporary_memory.length = 4096 * (temporary_memory.start != NULL); diff --git a/scripts/bench.ipynb b/scripts/bench.ipynb index 533be945..777a3093 100644 --- a/scripts/bench.ipynb +++ b/scripts/bench.ipynb @@ -2,28 +2,29 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "File ‘../leipzig1M.txt’ already 
there; not retrieving.\n" - ] - } - ], + "outputs": [], "source": [ "!wget --no-clobber -O ../leipzig1M.txt https://introcs.cs.princeton.edu/python/42sort/leipzig1m.txt" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install stringzilla memory_profiler" + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "import stringzilla as sz" + "%load_ext memory_profiler" ] }, { @@ -32,466 +33,425 @@ "metadata": {}, "outputs": [], "source": [ - "pythonic_str: str = open(\"../leipzig1M.txt\", \"r\").read()\n", - "sz_str = sz.Str(pythonic_str)\n", - "pattern = \"the\"" + "import stringzilla as sz" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "129,644,797, 129,644,797\n" - ] - } - ], + "outputs": [], "source": [ - "print(f\"{len(pythonic_str):,}, {len(sz_str):,}\")" + "pythonic_str: str = open(\"../leipzig1M.txt\", \"r\").read()\n", + "sz_str = sz.Str(pythonic_str)\n", + "pattern = \"the\"" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(1000000, 1000000)" + "[sz.Str('A'),\n", + " sz.Str('rebel'),\n", + " sz.Str('statement'),\n", + " sz.Str('sent'),\n", + " sz.Str('to')]" ] }, - "execution_count": 7, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "pythonic_str.count(\"\\n\"), sz_str.count(\"\\n\")" + "import itertools\n", + "top5 = itertools.islice(sz_str.split_iter(\" \"), 5) # grab the first five words\n", + "top5 = list(top5)\n", + "top5" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 5, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "129,644,797 characters taking 129,644,797 bytes\n" + ] + } + ], "source": [ - "## Throughput" + 
"print(f\"{len(pythonic_str):,} characters taking {len(sz_str):,} bytes\")" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "455 ms ± 23.6 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)\n" + "Both libraries report the same number of lines: 1,000,000\n" ] } ], "source": [ - "%%timeit -n 1 -r 10\n", - "sorted(pythonic_str.splitlines())" + "python_lines_count = pythonic_str.count(\"\\n\")\n", + "sz_lines_count = sz_str.count(\"\\n\")\n", + "assert python_lines_count == sz_lines_count\n", + "print(f\"Both libraries report the same number of lines: {python_lines_count:,}\")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "455 ms ± 17.1 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)\n" + "Total of 20,191,473 words of average length ~6.42 characters bytes or ~6.42 bytes\n" ] } ], "source": [ - "%%timeit -n 1 -r 10\n", - "sz_str.splitlines().sort()" + "count_words = pythonic_str.count(\" \")\n", + "print(f\"Total of {count_words:,} words of average length ~{len(pythonic_str) / count_words:.2f} characters bytes or ~{len(sz_str) / count_words:.2f} bytes\")" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "132 ms ± 13 ms per loop (mean ± std. dev. of 100 runs, 1 loop each)\n" + "1.33 s ± 4.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "651 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "465 ms ± 2.71 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 100\n", - "pythonic_str.count(pattern)" + "%timeit sum(1 for _ in pythonic_str.split())\n", + "%timeit sum(1 for _ in sz_str.split())\n", + "%timeit sum(1 for _ in sz_str.split_iter())" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "33.1 ms ± 7.74 ms per loop (mean ± std. dev. of 100 runs, 1 loop each)\n" + "peak memory: 2235.44 MiB, increment: 1422.00 MiB\n", + "peak memory: 890.62 MiB, increment: 77.00 MiB\n", + "peak memory: 890.62 MiB, increment: 0.00 MiB\n" ] } ], "source": [ - "%%timeit -n 1 -r 100\n", - "sz_str.count(pattern)" + "%memit sum(1 for _ in pythonic_str.split())\n", + "%memit sum(1 for _ in sz_str.split())\n", + "%memit sum(1 for _ in sz_str.split_iter())" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 11, "metadata": {}, + "outputs": [], "source": [ - "## Latency" + "import random" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "30.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "143 ms ± 7.21 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 1\n", - "hash(pythonic_str)" + "%%timeit -n 1 -r 10\n", + "random.choices(pythonic_str.splitlines(), k=1000)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "21.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "25.1 ms ± 481 µs per loop (mean ± std. dev. 
of 10 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 1\n", - "hash(sz_str)" + "%%timeit -n 1 -r 10\n", + "sz_str.splitlines().sample(1000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Throughput" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1.23 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "506 ms ± 12.5 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 1\n", - "pythonic_str.find(\" \")" + "%%timeit -n 1 -r 10\n", + "sorted(pythonic_str.splitlines())" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "3.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "382 ms ± 11.1 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 1\n", - "sz_str.find(\" \")" + "%%timeit -n 1 -r 10\n", + "sz_str.splitlines().sort()" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "87.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "144 ms ± 4 ms per loop (mean ± std. dev. of 100 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 1\n", - "pythonic_str.partition(\" \")" + "%%timeit -n 1 -r 100\n", + "pythonic_str.count(pattern)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "18.3 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "31.2 ms ± 1.06 ms per loop (mean ± std. dev. 
of 100 runs, 1 loop each)\n" ] } ], "source": [ - "%%timeit -n 1 -r 1\n", - "sz_str.partition(\" \")" + "%%timeit -n 1 -r 100\n", + "sz_str.count(pattern)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Sequences" + "## Latency" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "28.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 1\n", - "pythonic_str.split(\" \").sort()" + "hash(pythonic_str)" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "9.19 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" + "365 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 1\n", - "sz_str.split(\" \").sort()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Edit Distance" + "hash(sz_str)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: python-Levenshtein in /home/av/miniconda3/lib/python3.11/site-packages (0.23.0)\n", - "Requirement already satisfied: Levenshtein==0.23.0 in /home/av/miniconda3/lib/python3.11/site-packages (from python-Levenshtein) (0.23.0)\n", - "Requirement already satisfied: rapidfuzz<4.0.0,>=3.1.0 in /home/av/miniconda3/lib/python3.11/site-packages (from Levenshtein==0.23.0->python-Levenshtein) (3.5.2)\n", - "Requirement already satisfied: levenshtein in /home/av/miniconda3/lib/python3.11/site-packages (0.23.0)\n", - "Requirement already satisfied: rapidfuzz<4.0.0,>=3.1.0 in /home/av/miniconda3/lib/python3.11/site-packages (from levenshtein) (3.5.2)\n", - 
"Requirement already satisfied: jellyfish in /home/av/miniconda3/lib/python3.11/site-packages (1.0.3)\n", - "Requirement already satisfied: editdistance in /home/av/miniconda3/lib/python3.11/site-packages (0.6.2)\n", - "Collecting distance\n", - " Downloading Distance-0.1.3.tar.gz (180 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m180.3/180.3 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", - "\u001b[?25hBuilding wheels for collected packages: distance\n", - " Building wheel for distance (setup.py) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for distance: filename=Distance-0.1.3-py3-none-any.whl size=16258 sha256=b688ad5c13aada5f4d13ee0844df820e9f6260d94ac0456de71a70d11872ebf4\n", - " Stored in directory: /home/av/.cache/pip/wheels/fb/cd/9c/3ab5d666e3bcacc58900b10959edd3816cc9557c7337986322\n", - "Successfully built distance\n", - "Installing collected packages: distance\n", - "Successfully installed distance-0.1.3\n", - "Requirement already satisfied: polyleven in /home/av/miniconda3/lib/python3.11/site-packages (0.8)\n" + "2.34 µs ± 0 ns per loop (mean ± std. dev. 
of 1 run, 1 loop each)\n" ] } ], "source": [ - "!pip install python-Levenshtein # 4.8 M/mo: https://github.com/maxbachmann/python-Levenshtein\n", - "!pip install levenshtein # 4.2 M/mo: https://github.com/maxbachmann/Levenshtein\n", - "!pip install jellyfish # 2.3 M/mo: https://github.com/jamesturk/jellyfish/\n", - "!pip install editdistance # 700 k/mo: https://github.com/roy-ht/editdistance\n", - "!pip install distance # 160 k/mo: https://github.com/doukremt/distance\n", - "!pip install polyleven # 34 k/mo: https://github.com/fujimotos/polyleven" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "words = pythonic_str.split(\" \")" + "%%timeit -n 1 -r 1\n", + "pythonic_str.find(\" \")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "4.5 s ± 55.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - }, - { - "ename": "", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + "4.37 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" ] } ], "source": [ - "%%timeit\n", - "for word in words:\n", - " sz.levenshtein(word, \"rebel\")\n", - " sz.levenshtein(word, \"statement\")\n", - " sz.levenshtein(word, \"sent\")" + "%%timeit -n 1 -r 1\n", + "sz_str.find(\" \")" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.5 ms ± 0 ns per loop (mean ± std. dev. 
of 1 run, 1 loop each)\n" + ] + } + ], "source": [ - "import polyleven as pl" + "%%timeit -n 1 -r 1\n", + "pythonic_str.partition(\" \")" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "4.49 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "4.65 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" ] } ], "source": [ - "%%timeit\n", - "for word in words:\n", - " pl.levenshtein(word, \"rebel\", 100)\n", - " pl.levenshtein(word, \"statement\", 100)\n", - " pl.levenshtein(word, \"sent\", 100)" + "%%timeit -n 1 -r 1\n", + "sz_str.partition(\" \")" ] }, { - "cell_type": "code", - "execution_count": 15, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import editdistance as ed" + "## Sequences" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "24.9 s ± 300 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "6.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n" ] } ], "source": [ - "%%timeit\n", - "for word in words:\n", - " ed.eval(word, \"rebel\")\n", - " ed.eval(word, \"statement\")\n", - " ed.eval(word, \"sent\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "import jellyfish as jf" + "%%timeit -n 1 -r 1\n", + "pythonic_str.split(\" \").sort()" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "21.8 s ± 390 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "8.86 s ± 0 ns per loop (mean ± std. dev. 
of 1 run, 1 loop each)\n" ] } ], "source": [ - "%%timeit\n", - "for word in words:\n", - " jf.levenshtein_distance(word, \"rebel\")\n", - " jf.levenshtein_distance(word, \"statement\")\n", - " jf.levenshtein_distance(word, \"sent\")" + "%%timeit -n 1 -r 1\n", + "sz_str.split(\" \").sort()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/scripts/test.py b/scripts/test.py index 01ed4e2b..be2aa719 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -95,45 +95,6 @@ def test_unit_str_rich_comparisons(): assert s2[-2:] == "bb" -def test_unit_strs_rich_comparisons(): - arr: Strs = Str("a b c d e f g h").split() - - # Test against another Strs object - identical_arr: Strs = Str("a b c d e f g h").split() - different_arr: Strs = Str("a b c d e f g i").split() - shorter_arr: Strs = Str("a b c d e").split() - longer_arr: Strs = Str("a b c d e f g h i j").split() - - assert arr == identical_arr - assert arr != different_arr - assert arr != shorter_arr - assert arr != longer_arr - assert shorter_arr < arr - assert longer_arr > arr - - # Test against a Python list and a tuple - list_equal = ["a", "b", "c", "d", "e", "f", "g", "h"] - list_different = ["a", "b", "c", "d", "x", "f", "g", "h"] - tuple_equal = ("a", "b", "c", "d", "e", "f", "g", "h") - tuple_different = ("a", "b", "c", "d", "e", "f", "g", "i") - - assert arr == list_equal - assert arr != list_different - assert arr == tuple_equal - assert arr != tuple_different - - # Test against a generator of unknown length - generator_equal = (x for x in "a b c d e f g h".split()) - generator_different = (x for x in "a b c d e f g i".split()) - generator_shorter = (x for x in "a b c d e".split()) - generator_longer = (x for x in "a b c d e f g h i j".split()) - - assert arr == generator_equal - assert arr != generator_different - assert arr != generator_shorter - assert arr != generator_longer - - @pytest.mark.skipif(not 
numpy_available, reason="NumPy is not installed") def test_unit_buffer_protocol(): my_str = Str("hello") @@ -144,21 +105,127 @@ def test_unit_buffer_protocol(): def test_unit_split(): - native = "token1\ntoken2\ntoken3" + native = "line1\nline2\nline3" big = Str(native) + + # Splitting using a string + lines = sz.split(big, "\n") + assert lines == ["line1", "line2", "line3"] + + lines = sz.rsplit(big, "\n") + assert lines == ["line1", "line2", "line3"] + + lines = sz.split(big, "\n", keepseparator=True) + assert lines == ["line1\n", "line2\n", "line3"] + + letters = sz.split("a b c d") + assert letters == ["a", "b", "c", "d"] + + # Splitting using character sets + letters = sz.split_charset("a b_c d", " _") + assert letters == ["a", "b", "c", "d"] + + letters = sz.rsplit_charset("a b_c d", " _") + assert letters == ["a", "b", "c", "d"] + + # Check for equivalence with native Python strings for newline separators assert native.splitlines() == list(big.splitlines()) assert native.splitlines(True) == list(big.splitlines(keeplinebreaks=True)) - assert native.split("token3") == list(big.split("token3")) - words = sz.split(big, "\n") - assert len(words) == 3 - assert str(words[0]) == "token1" - assert str(words[2]) == "token3" + # Check for equivalence with native Python strings, including boundary conditions + assert native.split("line1") == list(big.split("line1")) + assert native.split("line3") == list(big.split("line3")) + assert native.split("\n", maxsplit=0) == list(big.split("\n", maxsplit=0)) + assert native.split("\n", maxsplit=1) == list(big.split("\n", maxsplit=1)) + assert native.split("\n", maxsplit=2) == list(big.split("\n", maxsplit=2)) + assert native.split("\n", maxsplit=3) == list(big.split("\n", maxsplit=3)) + assert native.split("\n", maxsplit=4) == list(big.split("\n", maxsplit=4)) + + # Check for equivalence with native Python strings in reverse order, including boundary conditions + assert native.rsplit("line1") == list(big.rsplit("line1")) + 
assert native.rsplit("line3") == list(big.rsplit("line3")) + assert native.rsplit("\n", maxsplit=0) == list(big.rsplit("\n", maxsplit=0)) + assert native.rsplit("\n", maxsplit=1) == list(big.rsplit("\n", maxsplit=1)) + assert native.rsplit("\n", maxsplit=2) == list(big.rsplit("\n", maxsplit=2)) + assert native.rsplit("\n", maxsplit=3) == list(big.rsplit("\n", maxsplit=3)) + assert native.rsplit("\n", maxsplit=4) == list(big.rsplit("\n", maxsplit=4)) + + # If the passed separator is an empty string, the library must raise a `ValueError` + with pytest.raises(ValueError): + sz.split(big, "") + with pytest.raises(ValueError): + sz.rsplit(big, "") + with pytest.raises(ValueError): + sz.split_charset(big, "") + with pytest.raises(ValueError): + sz.rsplit_charset(big, "") + + +def test_unit_split_iterators(): + """ + Test the iterator-based split methods. + This is slightly different from `split` and `rsplit` in that it returns an iterator instead of a list. + Moreover, the native `rsplit` and even `rsplit_charset` report results in the identical order to `split` + and `split_charset`. Here `rsplit_iter` reports elements in the reverse order, compared to `split_iter`. 
+ """ + native = "line1\nline2\nline3" + big = Str(native) - parts = sz.split(big, "\n", keepseparator=True) - assert len(parts) == 3 - assert str(parts[0]) == "token1\n" - assert str(parts[2]) == "token3" + # Splitting using a string + lines = list(sz.split_iter(big, "\n")) + assert lines == ["line1", "line2", "line3"] + + lines = list(sz.rsplit_iter(big, "\n")) + assert lines == ["line3", "line2", "line1"] + + lines = list(sz.split_iter(big, "\n", keepseparator=True)) + assert lines == ["line1\n", "line2\n", "line3"] + + lines = list(sz.rsplit_iter(big, "\n", keepseparator=True)) + assert lines == ["\nline3", "\nline2", "line1"] + + letters = list(sz.split_iter("a b c d")) + assert letters == ["a", "b", "c", "d"] + + # Splitting using character sets + letters = list(sz.split_charset_iter("a-b_c-d", "-_")) + assert letters == ["a", "b", "c", "d"] + + letters = list(sz.rsplit_charset_iter("a-b_c-d", "-_")) + assert letters == ["d", "c", "b", "a"] + + # Check for equivalence with native Python strings, including boundary conditions + assert native.split("line1") == list(big.split_iter("line1")) + assert native.split("line3") == list(big.split_iter("line3")) + assert native.split("\n", maxsplit=0) == list(big.split_iter("\n", maxsplit=0)) + assert native.split("\n", maxsplit=1) == list(big.split_iter("\n", maxsplit=1)) + assert native.split("\n", maxsplit=2) == list(big.split_iter("\n", maxsplit=2)) + assert native.split("\n", maxsplit=3) == list(big.split_iter("\n", maxsplit=3)) + assert native.split("\n", maxsplit=4) == list(big.split_iter("\n", maxsplit=4)) + + def rlist(seq): + seq = list(seq) + seq.reverse() + return seq + + # Check for equivalence with native Python strings in reverse order, including boundary conditions + assert native.rsplit("line1") == rlist(big.rsplit_iter("line1")) + assert native.rsplit("line3") == rlist(big.rsplit_iter("line3")) + assert native.rsplit("\n", maxsplit=0) == rlist(big.rsplit_iter("\n", maxsplit=0)) + assert 
native.rsplit("\n", maxsplit=1) == rlist(big.rsplit_iter("\n", maxsplit=1)) + assert native.rsplit("\n", maxsplit=2) == rlist(big.rsplit_iter("\n", maxsplit=2)) + assert native.rsplit("\n", maxsplit=3) == rlist(big.rsplit_iter("\n", maxsplit=3)) + assert native.rsplit("\n", maxsplit=4) == rlist(big.rsplit_iter("\n", maxsplit=4)) + + # If the passed separator is an empty string, the library must raise a `ValueError` + with pytest.raises(ValueError): + sz.split_iter(big, "") + with pytest.raises(ValueError): + sz.rsplit_iter(big, "") + with pytest.raises(ValueError): + sz.split_charset_iter(big, "") + with pytest.raises(ValueError): + sz.rsplit_charset_iter(big, "") def test_unit_strs_sequence(): @@ -197,6 +264,45 @@ def test_unit_slicing(): assert big[:-3] == "abc" +def test_unit_strs_rich_comparisons(): + arr: Strs = Str("a b c d e f g h").split() + + # Test against another Strs object + identical_arr: Strs = Str("a b c d e f g h").split() + different_arr: Strs = Str("a b c d e f g i").split() + shorter_arr: Strs = Str("a b c d e").split() + longer_arr: Strs = Str("a b c d e f g h i j").split() + + assert arr == identical_arr + assert arr != different_arr + assert arr != shorter_arr + assert arr != longer_arr + assert shorter_arr < arr + assert longer_arr > arr + + # Test against a Python list and a tuple + list_equal = ["a", "b", "c", "d", "e", "f", "g", "h"] + list_different = ["a", "b", "c", "d", "x", "f", "g", "h"] + tuple_equal = ("a", "b", "c", "d", "e", "f", "g", "h") + tuple_different = ("a", "b", "c", "d", "e", "f", "g", "i") + + assert arr == list_equal + assert arr != list_different + assert arr == tuple_equal + assert arr != tuple_different + + # Test against a generator of unknown length + generator_equal = (x for x in "a b c d e f g h".split()) + generator_different = (x for x in "a b c d e f g i".split()) + generator_shorter = (x for x in "a b c d e".split()) + generator_longer = (x for x in "a b c d e f g h i j".split()) + + assert arr == 
generator_equal + assert arr != generator_different + assert arr != generator_shorter + assert arr != generator_longer + + def test_unit_strs_sequence_slicing(): native = "1, 2, 3, 4, 5, 6" big = Str(native) @@ -232,6 +338,27 @@ def test_unit_globals(): assert sz.find("abcdef", "bcdef") == 1 assert sz.find("abcdef", "x") == -1 + assert sz.rfind("abcdef", "bcdef") == 1 + assert sz.rfind("abcdef", "x") == -1 + + # Corner-cases for `find` and `rfind`, when we pass empty strings + assert sz.find("abcdef", "") == "abcdef".find("") + assert sz.rfind("abcdef", "") == "abcdef".rfind("") + assert sz.find("abcdef", "", 1) == "abcdef".find("", 1) + assert sz.rfind("abcdef", "", 1) == "abcdef".rfind("", 1) + assert sz.find("abcdef", "", 1, 3) == "abcdef".find("", 1, 3) + assert sz.rfind("abcdef", "", 1, 3) == "abcdef".rfind("", 1, 3) + assert sz.find("", "abcdef") == "".find("abcdef") + assert sz.rfind("", "abcdef") == "".rfind("abcdef") + + # Compare partitioning functions + assert sz.partition("abcdef", "c") == ("ab", "c", "def") + assert sz.rpartition("abcdef", "c") == ("ab", "c", "def") + + with pytest.raises(ValueError): + sz.partition("abcdef", "") + with pytest.raises(ValueError): + sz.rpartition("abcdef", "") assert sz.count("abcdef", "x") == 0 assert sz.count("aaaaa", "a") == 5 From 3f9f1977cc41c0429a34c6d997abf2ba5a911cfb Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 1 Mar 2024 23:01:27 +0000 Subject: [PATCH 07/13] Fix: `split_iter(..., keepseparator=True)` --- python/lib.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib.c b/python/lib.c index ac11741d..73119dbe 100644 --- a/python/lib.c +++ b/python/lib.c @@ -2449,9 +2449,9 @@ static PyObject *SplitIteratorType_next(SplitIterator *self) { else { result_start = self->text.start; result_length = found - self->text.start; - result_length += self->match_length * self->include_match; self->text.start = found + self->match_length; 
self->text.length -= result_length + self->match_length; + result_length += self->match_length * self->include_match; } self->max_parts--; } From 1c813e43160d08888486d82d1ded6253ce0ff863 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 1 Mar 2024 23:06:39 +0000 Subject: [PATCH 08/13] Improve: Dynamic-dispatch for `sz_generate` --- c/lib.c | 5 +++++ include/stringzilla/stringzilla.h | 13 +++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/c/lib.c b/c/lib.c index e58e215f..5721265c 100644 --- a/c/lib.c +++ b/c/lib.c @@ -302,3 +302,8 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ sz_charset_invert(&set); return sz_rfind_charset(h, h_length, &set); } + +SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, + sz_random_generator_t generator, void *generator_user_data) { + return sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); +} diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index b216409b..00e72a9c 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -497,9 +497,13 @@ SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); * @param generate Callback producing random numbers given the generator state. * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. 
*/ -SZ_PUBLIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, +SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, void *generator); +/** @copydoc sz_generate */ +SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, + sz_random_generator_t generate, void *generator); + /** * @brief Similar to `memcpy`, copies contents of one string into another. * The behavior is undefined if the strings overlap. @@ -2984,7 +2988,7 @@ SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { return sz_true_k; } -SZ_PUBLIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, +SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); @@ -5324,6 +5328,11 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ return sz_rfind_charset(h, h_length, &set); } +SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, + sz_random_generator_t generator, void *generator_user_data) { + return sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); +} + #endif #pragma endregion From 40eb12d1f4499c1b5e88cfb3eb0f4e6688eb71a2 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 1 Mar 2024 23:10:30 +0000 Subject: [PATCH 09/13] Fix: Missing `pytest.mark.skipif` for NumPy and Arrow --- .github/workflows/prerelease.yml | 8 ++++---- scripts/test.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/prerelease.yml 
b/.github/workflows/prerelease.yml index 4ecaa0e9..e9abbc34 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -78,7 +78,7 @@ jobs: - name: Build Python run: | python -m pip install --upgrade pip - pip install pytest pytest-repeat + pip install pytest pytest-repeat numpy pyarrow python -m pip install . - name: Test Python run: pytest scripts/test.py -s -x @@ -160,7 +160,7 @@ jobs: - name: Build Python run: | python -m pip install --upgrade pip - pip install pytest pytest-repeat + pip install pytest pytest-repeat numpy pyarrow python -m pip install . - name: Test Python run: pytest scripts/test.py -s -x @@ -235,7 +235,7 @@ jobs: - name: Build Python run: | python -m pip install --upgrade pip - pip install pytest pytest-repeat + pip install pytest pytest-repeat numpy pyarrow python -m pip install . - name: Test Python run: pytest scripts/test.py -s -x @@ -270,7 +270,7 @@ jobs: - name: Build Python run: | python -m pip install --upgrade pip - pip install pytest pytest-repeat + pip install pytest pytest-repeat numpy pyarrow python -m pip install . 
- name: Test Python run: pytest scripts/test.py -s -x diff --git a/scripts/test.py b/scripts/test.py index be2aa719..2c180cbc 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -587,6 +587,7 @@ def test_edit_distances(): @pytest.mark.repeat(30) @pytest.mark.parametrize("first_length", [20, 100]) @pytest.mark.parametrize("second_length", [20, 100]) +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") def test_edit_distance_random(first_length: int, second_length: int): a = get_random_string(length=first_length) b = get_random_string(length=second_length) From 6798c4e16c708103beffe3be87b64ae977fb3e6d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 2 Mar 2024 00:32:48 +0000 Subject: [PATCH 10/13] Fix: no `return` in `void` funcs in C 99 --- Cargo.lock | 2 +- c/lib.c | 2 +- include/stringzilla/stringzilla.h | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 080c826a..651779fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,7 +19,7 @@ checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "stringzilla" -version = "3.2.0" +version = "3.3.1" dependencies = [ "cc", ] diff --git a/c/lib.c b/c/lib.c index 5721265c..980c069f 100644 --- a/c/lib.c +++ b/c/lib.c @@ -305,5 +305,5 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { - return sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); + sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); } diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 00e72a9c..5ed681e3 100644 --- a/include/stringzilla/stringzilla.h +++ 
b/include/stringzilla/stringzilla.h @@ -498,7 +498,7 @@ SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. */ SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); + sz_random_generator_t generate, void *generator); /** @copydoc sz_generate */ SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, @@ -2989,7 +2989,7 @@ SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { } SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { + sz_random_generator_t generator, void *generator_user_data) { sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); @@ -5330,7 +5330,7 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { - return sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); + sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); } #endif From 30398bc392b7e353f987a98b1b43e84fb901e72f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 2 Mar 2024 05:16:38 +0000 Subject: [PATCH 11/13] Add: Similarity measures for Rust --- README.md | 37 +- c/lib.c | 21 + include/stringzilla/stringzilla.h | 188 ++++---- rust/lib.rs | 733 +++++++++++++++++++++++------- 4 files changed, 735 insertions(+), 244 deletions(-) diff --git a/README.md b/README.md index b4905bdd..e6592fff 100644 
--- a/README.md +++ b/README.md @@ -1093,9 +1093,27 @@ __`STRINGZILLA_BUILD_SHARED`, `STRINGZILLA_BUILD_TEST`, `STRINGZILLA_BUILD_BENCH ## Quick Start: Rust 🦀 StringZilla is available as a Rust crate. -It currently covers only the most basic functionality, but is planned to be extended to cover the full C++ API. +Some of the interfaces will look familiar to the users of the `memchr` crate. + +```rust +use stringzilla::sz; + +// Identical to `memchr::memmem::find` and `memchr::memmem::rfind` functions +sz::find("Hello, world!", "world") // 7 +sz::rfind("Hello, world!", "world") // 7 + +// Generalizations of `memchr::memrchr[123]` +sz::find_char_from("Hello, world!", "world") // 2 +sz::rfind_char_from("Hello, world!", "world") // 11 +``` + +Unlike `memchr`, the throughput of `stringzilla` is [high in both normal and reverse-order searches][memchr-benchmarks]. +It also provides no constraints on the size of the character set, while `memchr` allows only 1, 2, or 3 characters. +In addition to global functions, `stringzilla` provides a `StringZilla` extension trait: ```rust +use stringzilla::StringZilla; + let my_string: String = String::from("Hello, world!"); let my_str = my_string.as_str(); let my_cow_str = Cow::from(&my_string); @@ -1113,6 +1131,23 @@ assert_eq!(my_str.sz_find("world"), Some(7)); assert_eq!(my_cow_str.as_ref().sz_find("world"), Some(7)); ``` +The library also exposes Levenshtein and Hamming edit-distances for byte-arrays and UTF-8 strings, as well as Needleman-Wunch alignment scores. 
+ +```rust +use stringzilla::sz; + +// Handling arbitrary byte arrays: +sz::edit_distance("Hello, world!", "Hello, world?"); // 1 +sz::hamming_distance("Hello, world!", "Hello, world?"); // 1 +sz::alignment_score("Hello, world!", "Hello, world?", sz::unary_substitution_costs(), -1); // -1 + +// Handling UTF-8 strings: +sz::hamming_distance_utf8("αβγδ", "αγγδ") // 1 +sz::edit_distance_utf8("façade", "facade") // 1 +``` + +[memchr-benchmarks]: https://github.com/ashvardanian/memchr_vs_stringzilla + ## Quick Start: Swift 🍏 StringZilla is available as a Swift package. diff --git a/c/lib.c b/c/lib.c index 980c069f..bd837ff5 100644 --- a/c/lib.c +++ b/c/lib.c @@ -255,6 +255,20 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_chars return sz_dispatch_table.rfind_from_set(text, length, set); } +SZ_DYNAMIC sz_size_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + return sz_hamming_distance_serial(a, a_length, b, b_length, bound); +} + +SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); +} + SZ_DYNAMIC sz_size_t sz_edit_distance( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // @@ -262,6 +276,13 @@ SZ_DYNAMIC sz_size_t sz_edit_distance( // return sz_dispatch_table.edit_distance(a, a_length, b, b_length, bound, alloc); } +SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { + return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); +} + SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { 
diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 5ed681e3..cc5b7b7f 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -768,8 +768,12 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz * @see sz_hamming_distance_utf8 * @see https://en.wikipedia.org/wiki/Hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); +SZ_DYNAMIC sz_size_t sz_hamming_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, + sz_size_t bound); + +/** @copydoc sz_hamming_distance */ +SZ_PUBLIC sz_size_t sz_hamming_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, + sz_size_t bound); /** * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. @@ -787,8 +791,12 @@ SZ_PUBLIC sz_size_t sz_hamming_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr * @see sz_hamming_distance * @see https://en.wikipedia.org/wiki/Hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); +SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, + sz_size_t bound); + +/** @copydoc sz_hamming_distance_utf8 */ +SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, + sz_size_t bound); typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); @@ -839,11 +847,15 @@ SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_ * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance * @see https://en.wikipedia.org/wiki/Levenshtein_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, 
sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); +SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); +/** @copydoc sz_edit_distance_utf8 */ +SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); + /** * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. @@ -2651,6 +2663,74 @@ SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // return result; } +SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + + sz_size_t const min_length = sz_min_of_two(a_length, b_length); + sz_size_t const max_length = sz_max_of_two(a_length, b_length); + sz_cptr_t const a_end = a + min_length; + bound = bound == 0 ? max_length : bound; + + // Walk through both strings using SWAR and counting the number of differing characters. 
+ sz_size_t distance = max_length - min_length; +#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN + if (min_length >= SZ_SWAR_THRESHOLD) { + sz_u64_vec_t a_vec, b_vec, match_vec; + for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { + a_vec.u64 = sz_u64_load(a).u64; + b_vec.u64 = sz_u64_load(b).u64; + match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); + distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); + } + } +#endif + + for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } + return sz_min_of_two(distance, bound); +} + +SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + + sz_cptr_t const a_end = a + a_length; + sz_cptr_t const b_end = b + b_length; + sz_size_t distance = 0; + + sz_rune_t a_rune, b_rune; + sz_rune_length_t a_rune_length, b_rune_length; + + if (bound) { + for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { + _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); + distance += (a_rune != b_rune); + } + // If one string has more runes, we need to go through the tail. + if (distance < bound) { + for (; a < a_end && distance < bound; a += a_rune_length, ++distance) + _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + + for (; b < b_end && distance < bound; b += b_rune_length, ++distance) + _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); + } + } + else { + for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { + _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); + distance += (a_rune != b_rune); + } + // If one string has more runes, we need to go through the tail. 
+ for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); + } + return distance; +} + /** * @brief Largest prime number that fits into 31 bits. * @see https://mersenneforum.org/showthread.php?t=3471 @@ -5075,81 +5155,6 @@ SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_ sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); } -SZ_PUBLIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. 
- sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. 
- for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} - #if !SZ_DYNAMIC_DISPATCH SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { @@ -5266,6 +5271,20 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_chars #endif } +SZ_DYNAMIC sz_size_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + return sz_hamming_distance_serial(a, a_length, b, b_length, bound); +} + +SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); +} + SZ_DYNAMIC sz_size_t sz_edit_distance( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // @@ -5277,6 +5296,13 @@ SZ_DYNAMIC sz_size_t sz_edit_distance( // #endif } +SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { + return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); +} + SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { diff --git a/rust/lib.rs b/rust/lib.rs index fff99606..f7a9c997 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1,111 +1,123 @@ #![cfg_attr(not(test), no_std)] -use 
core::ffi::c_void; - -// Import the functions from the StringZilla C library. -extern "C" { - fn sz_find( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_rfind( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_find_char_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_rfind_char_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_find_char_not_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_rfind_char_not_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_edit_distance( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, - bound: usize, - allocator: *const c_void, - ) -> usize; - - fn sz_alignment_score( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, - matrix: *const c_void, - gap: i8, - allocator: *const c_void, - ) -> isize; -} - -/// The [StringZilla] trait provides a collection of string searching and manipulation functionalities. -pub trait StringZilla -where - N: AsRef<[u8]>, -{ - /// Locates first matching substring. Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. Similar - /// to `strstr(haystack, needle)` in LibC, but requires known length. - fn sz_find(&self, needle: N) -> Option; - - /// Locates the last matching substring. - fn sz_rfind(&self, needle: N) -> Option; - - /// Finds the first character in the haystack, that is present in the needle. 
- fn sz_find_char_from(&self, needles: N) -> Option; - - /// Finds the last character in the haystack, that is present in the needle. - fn sz_rfind_char_from(&self, needles: N) -> Option; - - /// Finds the first character in the haystack, that is __not__ present in the needle. - fn sz_find_char_not_from(&self, needles: N) -> Option; - - /// Finds the last character in the haystack, that is __not__ present in the needle. - fn sz_rfind_char_not_from(&self, needles: N) -> Option; - - /// Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - /// Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - fn sz_edit_distance(&self, needle: N) -> usize; - - /// Computes Needleman–Wunsch alignment score for two strings. Often used in bioinformatics and cheminformatics. - /// Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - fn sz_alignment_score(&self, needle: N, matrix: [[i8; 256]; 256], gap: i8) -> isize; -} +/// The `sz` module provides a collection of string searching and manipulation functionality, +/// designed for high efficiency and compatibility with no_std environments. This module offers +/// various utilities for byte string manipulation, including search, reverse search, and +/// edit-distance calculations, suitable for a wide range of applications from basic string +/// processing to complex text analysis tasks. + +pub mod sz { + + use core::ffi::c_void; + + // Import the functions from the StringZilla C library. 
+ extern "C" { + fn sz_find( + haystack: *const c_void, + haystack_length: usize, + needle: *const c_void, + needle_length: usize, + ) -> *const c_void; + + fn sz_rfind( + haystack: *const c_void, + haystack_length: usize, + needle: *const c_void, + needle_length: usize, + ) -> *const c_void; + + fn sz_find_char_from( + haystack: *const c_void, + haystack_length: usize, + needle: *const c_void, + needle_length: usize, + ) -> *const c_void; + + fn sz_rfind_char_from( + haystack: *const c_void, + haystack_length: usize, + needle: *const c_void, + needle_length: usize, + ) -> *const c_void; + + fn sz_find_char_not_from( + haystack: *const c_void, + haystack_length: usize, + needle: *const c_void, + needle_length: usize, + ) -> *const c_void; + + fn sz_rfind_char_not_from( + haystack: *const c_void, + haystack_length: usize, + needle: *const c_void, + needle_length: usize, + ) -> *const c_void; + + fn sz_edit_distance( + haystack1: *const c_void, + haystack1_length: usize, + haystack2: *const c_void, + haystack2_length: usize, + bound: usize, + allocator: *const c_void, + ) -> usize; + + fn sz_edit_distance_utf8( + haystack1: *const c_void, + haystack1_length: usize, + haystack2: *const c_void, + haystack2_length: usize, + bound: usize, + allocator: *const c_void, + ) -> usize; + + fn sz_hamming_distance( + haystack1: *const c_void, + haystack1_length: usize, + haystack2: *const c_void, + haystack2_length: usize, + bound: usize, + ) -> usize; + + fn sz_hamming_distance_utf8( + haystack1: *const c_void, + haystack1_length: usize, + haystack2: *const c_void, + haystack2_length: usize, + bound: usize, + ) -> usize; + + fn sz_alignment_score( + haystack1: *const c_void, + haystack1_length: usize, + haystack2: *const c_void, + haystack2_length: usize, + matrix: *const c_void, + gap: i8, + allocator: *const c_void, + ) -> isize; + } -impl StringZilla for T -where - T: AsRef<[u8]>, - N: AsRef<[u8]>, -{ - fn sz_find(&self, needle: N) -> Option { - let haystack_ref = 
self.as_ref(); + /// Locates the first matching substring within `haystack` that equals `needle`. + /// This function is similar to the `memmem()` function in LibC, but, unlike `strstr()`, + /// it requires the length of both haystack and needle to be known beforehand. + /// + /// # Arguments + /// + /// * `haystack`: The byte slice to search. + /// * `needle`: The byte slice to find within the haystack. + /// + /// # Returns + /// + /// An `Option` representing the starting index of the first occurrence of `needle` + /// within `haystack` if found, otherwise `None`. + pub fn find(haystack: H, needle: N) -> Option + where + H: AsRef<[u8]>, + N: AsRef<[u8]>, + { + let haystack_ref = haystack.as_ref(); let needle_ref = needle.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); @@ -127,8 +139,25 @@ where } } - fn sz_rfind(&self, needle: N) -> Option { - let haystack_ref = self.as_ref(); + /// Locates the last matching substring within `haystack` that equals `needle`. + /// This function is useful for finding the most recent or last occurrence of a pattern + /// within a byte slice. + /// + /// # Arguments + /// + /// * `haystack`: The byte slice to search. + /// * `needle`: The byte slice to find within the haystack. + /// + /// # Returns + /// + /// An `Option` representing the starting index of the last occurrence of `needle` + /// within `haystack` if found, otherwise `None`. + pub fn rfind(haystack: H, needle: N) -> Option + where + H: AsRef<[u8]>, + N: AsRef<[u8]>, + { + let haystack_ref = haystack.as_ref(); let needle_ref = needle.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); @@ -150,8 +179,25 @@ where } } - fn sz_find_char_from(&self, needles: N) -> Option { - let haystack_ref = self.as_ref(); + /// Finds the index of the first character in `haystack` that is also present in `needles`. 
+ /// This function is particularly useful for parsing and tokenization tasks where a set of + /// delimiter characters is used. + /// + /// # Arguments + /// + /// * `haystack`: The byte slice to search. + /// * `needles`: The set of bytes to search for within the haystack. + /// + /// # Returns + /// + /// An `Option` representing the index of the first occurrence of any byte from + /// `needles` within `haystack`, if found, otherwise `None`. + pub fn find_char_from(haystack: H, needles: N) -> Option + where + H: AsRef<[u8]>, + N: AsRef<[u8]>, + { + let haystack_ref = haystack.as_ref(); let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); @@ -172,8 +218,25 @@ where } } - fn sz_rfind_char_from(&self, needles: N) -> Option { - let haystack_ref = self.as_ref(); + /// Finds the index of the last character in `haystack` that is also present in `needles`. + /// This can be used to find the last occurrence of any character from a specified set, + /// useful in parsing scenarios such as finding the last delimiter in a string. + /// + /// # Arguments + /// + /// * `haystack`: The byte slice to search. + /// * `needles`: The set of bytes to search for within the haystack. + /// + /// # Returns + /// + /// An `Option` representing the index of the last occurrence of any byte from + /// `needles` within `haystack`, if found, otherwise `None`. + pub fn rfind_char_from(haystack: H, needles: N) -> Option + where + H: AsRef<[u8]>, + N: AsRef<[u8]>, + { + let haystack_ref = haystack.as_ref(); let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); @@ -194,8 +257,25 @@ where } } - fn sz_find_char_not_from(&self, needles: N) -> Option { - let haystack_ref = self.as_ref(); + /// Finds the index of the first character in `haystack` that is not present in `needles`. 
+ /// This function is useful for skipping over a known set of characters and finding the + /// first character that does not belong to that set. + /// + /// # Arguments + /// + /// * `haystack`: The byte slice to search. + /// * `needles`: The set of bytes that should not be matched within the haystack. + /// + /// # Returns + /// + /// An `Option` representing the index of the first occurrence of any byte not in + /// `needles` within `haystack`, if found, otherwise `None`. + pub fn find_char_not_from(haystack: H, needles: N) -> Option + where + H: AsRef<[u8]>, + N: AsRef<[u8]>, + { + let haystack_ref = haystack.as_ref(); let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); @@ -216,8 +296,25 @@ where } } - fn sz_rfind_char_not_from(&self, needles: N) -> Option { - let haystack_ref = self.as_ref(); + /// Finds the index of the last character in `haystack` that is not present in `needles`. + /// Useful for text processing tasks such as trimming trailing characters that belong to + /// a specified set. + /// + /// # Arguments + /// + /// * `haystack`: The byte slice to search. + /// * `needles`: The set of bytes that should not be matched within the haystack. + /// + /// # Returns + /// + /// An `Option` representing the index of the last occurrence of any byte not in + /// `needles` within `haystack`, if found, otherwise `None`. 
+ pub fn rfind_char_not_from(haystack: H, needles: N) -> Option + where + H: AsRef<[u8]>, + N: AsRef<[u8]>, + { + let haystack_ref = haystack.as_ref(); let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); @@ -238,57 +335,288 @@ where } } - fn sz_edit_distance(&self, needle: N) -> usize { - let haystack_ref = self.as_ref(); - let needle_ref = needle.as_ref(); - let haystack_length = haystack_ref.len(); - let needle_length = needle_ref.len(); - let haystack_pointer = haystack_ref.as_ptr() as _; - let needle_pointer = needle_ref.as_ptr() as _; + /// Computes the Levenshtein edit distance between two strings, using the Wagner-Fisher + /// algorithm. This measure is widely used in applications like spell-checking, DNA sequence + /// analysis. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// * `bound`: The maximum distance to compute, allowing for early exit. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (insertions, + /// deletions, or substitutions) required to change `first` into `second`. + pub fn edit_distance_bounded(first: F, second: S, bound: usize) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + let first_ref = first.as_ref(); + let second_ref = second.as_ref(); + let first_length = first_ref.len(); + let second_length = second_ref.len(); + let first_pointer = first_ref.as_ptr() as _; + let second_pointer = second_ref.as_ptr() as _; unsafe { sz_edit_distance( - haystack_pointer, - haystack_length, - needle_pointer, - needle_length, + first_pointer, + first_length, + second_pointer, + second_length, // Upper bound on the distance, that allows us to exit early. If zero is // passed, the maximum possible distance will be equal to the length of // the longer input. 
- 0, + bound, // Uses the default allocator core::ptr::null(), ) } } - fn sz_alignment_score(&self, needle: N, matrix: [[i8; 256]; 256], gap: i8) -> isize { - let haystack_ref = self.as_ref(); - let needle_ref = needle.as_ref(); - let haystack_length = haystack_ref.len(); - let needle_length = needle_ref.len(); - let haystack_pointer = haystack_ref.as_ptr() as _; - let needle_pointer = needle_ref.as_ptr() as _; + /// Computes the Levenshtein edit distance between two UTF8 strings, using the Wagner-Fisher + /// algorithm. This measure is widely used in applications like spell-checking. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// * `bound`: The maximum distance to compute, allowing for early exit. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (insertions, + /// deletions, or substitutions) required to change `first` into `second`. + pub fn edit_distance_utf8_bounded(first: F, second: S, bound: usize) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + let first_ref = first.as_ref(); + let second_ref = second.as_ref(); + let first_length = first_ref.len(); + let second_length = second_ref.len(); + let first_pointer = first_ref.as_ptr() as _; + let second_pointer = second_ref.as_ptr() as _; + unsafe { + sz_edit_distance_utf8( + first_pointer, + first_length, + second_pointer, + second_length, + // Upper bound on the distance, that allows us to exit early. If zero is + // passed, the maximum possible distance will be equal to the length of + // the longer input. + bound, + // Uses the default allocator + core::ptr::null(), + ) + } + } + + /// Computes the Levenshtein edit distance between two strings, using the Wagner-Fisher + /// algorithm. This measure is widely used in applications like spell-checking, DNA sequence + /// analysis. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. 
+ /// * `second`: The second byte slice. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (insertions, + /// deletions, or substitutions) required to change `first` into `second`. + pub fn edit_distance(first: F, second: S) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + edit_distance_bounded(first, second, 0) + } + + /// Computes the Levenshtein edit distance between two UTF8 strings, using the Wagner-Fisher + /// algorithm. This measure is widely used in applications like spell-checking. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (insertions, + /// deletions, or substitutions) required to change `first` into `second`. + pub fn edit_distance_utf8(first: F, second: S) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + edit_distance_utf8_bounded(first, second, 0) + } + + /// Computes the Hamming edit distance between two strings, counting the number of substituted characters. + /// Difference in length is added to the result as well. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// * `bound`: The maximum distance to compute, allowing for early exit. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (substitutions) required to + /// change `first` into `second`. 
+ pub fn hamming_distance_bounded(first: F, second: S, bound: usize) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + let first_ref = first.as_ref(); + let second_ref = second.as_ref(); + let first_length = first_ref.len(); + let second_length = second_ref.len(); + let first_pointer = first_ref.as_ptr() as _; + let second_pointer = second_ref.as_ptr() as _; + unsafe { + sz_hamming_distance( + first_pointer, + first_length, + second_pointer, + second_length, + // Upper bound on the distance, that allows us to exit early. If zero is + // passed, the maximum possible distance will be equal to the length of + // the longer input. + bound, + ) + } + } + + /// Computes the Hamming edit distance between two UTF8 strings, counting the number of substituted + /// variable-length characters. Difference in length is added to the result as well. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// * `bound`: The maximum distance to compute, allowing for early exit. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (substitutions) required to + /// change `first` into `second`. + pub fn hamming_distance_utf8_bounded(first: F, second: S, bound: usize) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + let first_ref = first.as_ref(); + let second_ref = second.as_ref(); + let first_length = first_ref.len(); + let second_length = second_ref.len(); + let first_pointer = first_ref.as_ptr() as _; + let second_pointer = second_ref.as_ptr() as _; + unsafe { + sz_hamming_distance_utf8( + first_pointer, + first_length, + second_pointer, + second_length, + // Upper bound on the distance, that allows us to exit early. If zero is + // passed, the maximum possible distance will be equal to the length of + // the longer input. + bound, + ) + } + } + + /// Computes the Hamming edit distance between two strings, counting the number of substituted characters. 
+ /// Difference in length is added to the result as well. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (substitutions) required to + /// change `first` into `second`. + pub fn hamming_distance(first: F, second: S) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + hamming_distance_bounded(first, second, 0) + } + + /// Computes the Hamming edit distance between two UTF8 strings, counting the number of substituted + /// variable-length characters. Difference in length is added to the result as well. + /// + /// # Arguments + /// + /// * `first`: The first byte slice. + /// * `second`: The second byte slice. + /// + /// # Returns + /// + /// A `usize` representing the minimum number of single-character edits (substitutions) required to + /// change `first` into `second`. + pub fn hamming_distance_utf8(first: F, second: S) -> usize + where + F: AsRef<[u8]>, + S: AsRef<[u8]>, + { + hamming_distance_utf8_bounded(first, second, 0) + } + + /// Computes the Needleman-Wunsch alignment score for two strings. This function is + /// particularly used in bioinformatics for sequence alignment but is also applicable in + /// other domains requiring detailed comparison between two strings, including gap and + /// substitution penalties. + /// + /// # Arguments + /// + /// * `first`: The first byte slice to align. + /// * `second`: The second byte slice to align. + /// * `matrix`: The substitution matrix used for scoring. + /// * `gap`: The penalty for each gap introduced during alignment. + /// + /// # Returns + /// + /// An `isize` representing the total alignment score, where higher scores indicate better + /// alignment between the two strings, considering the specified gap penalties and + /// substitution matrix. 
+    pub fn alignment_score<F, S>(first: F, second: S, matrix: [[i8; 256]; 256], gap: i8) -> isize
+    where
+        F: AsRef<[u8]>,
+        S: AsRef<[u8]>,
+    {
+        let first_ref = first.as_ref();
+        let second_ref = second.as_ref();
+        let first_length = first_ref.len();
+        let second_length = second_ref.len();
+        let first_pointer = first_ref.as_ptr() as _;
+        let second_pointer = second_ref.as_ptr() as _;
         unsafe {
             sz_alignment_score(
-                haystack_pointer,
-                haystack_length,
-                needle_pointer,
-                needle_length,
+                first_pointer,
+                first_length,
+                second_pointer,
+                second_length,
                 matrix.as_ptr() as _,
                 gap,
                 core::ptr::null(),
             )
         }
     }
-}

-#[cfg(test)]
-mod tests {
-    use std::borrow::Cow;
-
-    use crate::StringZilla;
-
-    fn unary_substitution_costs() -> [[i8; 256]; 256] {
+    /// The default substitution matrix for the Needleman-Wunsch alignment algorithm,
+    /// which will produce distances equal to the negative Levenshtein edit distance.
+    pub fn unary_substitution_costs() -> [[i8; 256]; 256] {
         let mut result = [[0; 256]; 256];
@@ -299,39 +627,120 @@ mod tests {
         result
     }
+}
+
+/// The [StringZilla] trait provides a collection of string searching and manipulation functionalities.
+pub trait StringZilla +where + N: AsRef<[u8]>, +{ + fn sz_find(&self, needle: N) -> Option; + fn sz_rfind(&self, needle: N) -> Option; + fn sz_find_char_from(&self, needles: N) -> Option; + fn sz_rfind_char_from(&self, needles: N) -> Option; + fn sz_find_char_not_from(&self, needles: N) -> Option; + fn sz_rfind_char_not_from(&self, needles: N) -> Option; + fn sz_edit_distance(&self, other: N) -> usize; + fn sz_alignment_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> isize; +} + +impl StringZilla for T +where + T: AsRef<[u8]>, + N: AsRef<[u8]>, +{ + fn sz_find(&self, needle: N) -> Option { + sz::find(self, needle) + } + + fn sz_rfind(&self, needle: N) -> Option { + sz::rfind(self, needle) + } + + fn sz_find_char_from(&self, needles: N) -> Option { + sz::find_char_from(self, needles) + } + + fn sz_rfind_char_from(&self, needles: N) -> Option { + sz::rfind_char_from(self, needles) + } + + fn sz_find_char_not_from(&self, needles: N) -> Option { + sz::find_char_not_from(self, needles) + } + + fn sz_rfind_char_not_from(&self, needles: N) -> Option { + sz::rfind_char_not_from(self, needles) + } + + fn sz_edit_distance(&self, other: N) -> usize { + sz::edit_distance(self, other) + } + + fn sz_alignment_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> isize { + sz::alignment_score(self, other, matrix, gap) + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use crate::sz; + use crate::StringZilla; + + #[test] + fn hamming() { + assert_eq!(sz::hamming_distance("hello", "hello"), 0); + assert_eq!(sz::hamming_distance("hello", "hell"), 1); + assert_eq!(sz::hamming_distance("abc", "adc"), 1); + + assert_eq!(sz::hamming_distance_bounded("abcdefgh", "ABCDEFGH", 2), 2); + assert_eq!(sz::hamming_distance_utf8("αβγδ", "αγγδ"), 1); + } #[test] fn levenshtein() { - assert_eq!("hello".sz_edit_distance("hell"), 1); - assert_eq!("hello".sz_edit_distance("hell"), 1); - assert_eq!("abc".sz_edit_distance(""), 3); - 
assert_eq!("abc".sz_edit_distance("ac"), 1); - assert_eq!("abc".sz_edit_distance("a_bc"), 1); - assert_eq!("abc".sz_edit_distance("adc"), 1); - assert_eq!("ggbuzgjux{}l".sz_edit_distance("gbuzgjux{}l"), 1); - assert_eq!("abcdefgABCDEFG".sz_edit_distance("ABCDEFGabcdefg"), 14); - assert_eq!("fitting".sz_edit_distance("kitty"), 4); - assert_eq!("smitten".sz_edit_distance("mitten"), 1); + assert_eq!(sz::edit_distance("hello", "hell"), 1); + assert_eq!(sz::edit_distance("hello", "hell"), 1); + assert_eq!(sz::edit_distance("abc", ""), 3); + assert_eq!(sz::edit_distance("abc", "ac"), 1); + assert_eq!(sz::edit_distance("abc", "a_bc"), 1); + assert_eq!(sz::edit_distance("abc", "adc"), 1); + assert_eq!(sz::edit_distance("fitting", "kitty"), 4); + assert_eq!(sz::edit_distance("smitten", "mitten"), 1); + assert_eq!(sz::edit_distance("ggbuzgjux{}l", "gbuzgjux{}l"), 1); + assert_eq!(sz::edit_distance("abcdefgABCDEFG", "ABCDEFGabcdefg"), 14); + + assert_eq!(sz::edit_distance_bounded("fitting", "kitty", 2), 2); + assert_eq!(sz::edit_distance_utf8("façade", "facade"), 1); } #[test] fn needleman() { - let costs_vector = unary_substitution_costs(); - assert_eq!("listen".sz_alignment_score("silent", costs_vector, -1), -4); + let costs_vector = sz::unary_substitution_costs(); + assert_eq!( + sz::alignment_score("listen", "silent", costs_vector, -1), + -4 + ); assert_eq!( - "abcdefgABCDEFG".sz_alignment_score("ABCDEFGabcdefg", costs_vector, -1), + sz::alignment_score("abcdefgABCDEFG", "ABCDEFGabcdefg", costs_vector, -1), -14 ); - assert_eq!("hello".sz_alignment_score("hello", costs_vector, -1), 0); - assert_eq!("hello".sz_alignment_score("hell", costs_vector, -1), -1); + assert_eq!(sz::alignment_score("hello", "hello", costs_vector, -1), 0); + assert_eq!(sz::alignment_score("hello", "hell", costs_vector, -1), -1); } #[test] - fn basics() { + fn search() { let my_string: String = String::from("Hello, world!"); let my_str: &str = my_string.as_str(); let my_cow_str: Cow<'_, str> = 
Cow::from(&my_string); + // Identical to `memchr::memmem::find` and `memchr::memmem::rfind` functions + assert_eq!(sz::find("Hello, world!", "world"), Some(7)); + assert_eq!(sz::rfind("Hello, world!", "world"), Some(7)); + // Use the generic function with a String assert_eq!(my_string.sz_find("world"), Some(7)); assert_eq!(my_string.sz_rfind("world"), Some(7)); From 6998bcf43086023dea429afbcf94a6d77fb629de Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 2 Mar 2024 05:50:30 +0000 Subject: [PATCH 12/13] Fix: Handle `NULL` PRNGs --- c/lib.c | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/c/lib.c b/c/lib.c index bd837ff5..bcfe7834 100644 --- a/c/lib.c +++ b/c/lib.c @@ -23,6 +23,7 @@ typedef sz_size_t size_t; // Reuse the type definition we've inferred from `stri #else typedef __SIZE_TYPE__ size_t; // For GCC/Clang #endif +int rand(void) { return 0; } void free(void *start) { sz_unused(start); } void *malloc(size_t length) { sz_unused(length); @@ -324,7 +325,13 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ return sz_rfind_charset(h, h_length, &set); } +sz_u64_t _sz_random_generator(void *empty_state) { + sz_unused(empty_state); + return (sz_u64_t)rand(); +} + SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { + if (!generator) generator = _sz_random_generator; sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); } From 34997a3fc7d3c947a59356012848af9ddb3f99c7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 2 Mar 2024 05:56:34 +0000 Subject: [PATCH 13/13] Fix: Syntax issues --- python/lib.c | 10 ++++++---- rust/lib.rs | 45 +++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/python/lib.c 
b/python/lib.c index 73119dbe..d4b74e84 100644 --- a/python/lib.c +++ b/python/lib.c @@ -56,7 +56,8 @@ static sz_string_view_t temporary_memory = {NULL, 0}; * native `mmap` module, as it exposes the address of the mapping in memory. */ typedef struct { - PyObject_HEAD; + PyObject ob_base; + #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) HANDLE file_handle; HANDLE mapping_handle; @@ -80,7 +81,8 @@ typedef struct { * - Str(File("some-path.txt"), from=0, to=sys.maxint) */ typedef struct { - PyObject_HEAD; + PyObject ob_base; + PyObject *parent; sz_cptr_t start; sz_size_t length; @@ -93,7 +95,7 @@ typedef struct { * which might be more memory-friendly, than greedily invoking `str.split`. */ typedef struct { - PyObject_HEAD; + PyObject ob_base; PyObject *text_object; //< For reference counting PyObject *separator_object; //< For reference counting @@ -125,7 +127,7 @@ typedef struct { * for faster sorting, shuffling, joins, and lookups. */ typedef struct { - PyObject_HEAD; + PyObject ob_base; enum { STRS_CONSECUTIVE_32, diff --git a/rust/lib.rs b/rust/lib.rs index 46dd4cde..5a968417 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -625,8 +625,19 @@ pub mod sz { } } - /// The default substitution matrix for the Needleman-Wunsch alignment algorithm, - /// which will produce distances equal to the negative Levenshtein edit distance. + /// Generates a default substitution matrix for use with the Needleman-Wunsch + /// alignment algorithm. This matrix is initialized such that diagonal entries + /// (representing matching characters) are zero, and off-diagonal entries + /// (representing mismatches) are -1. This setup effectively produces distances + /// equal to the negative Levenshtein edit distance, suitable for basic sequence + /// alignment tasks where all mismatches are penalized equally and there are no + /// rewards for matches. 
+    ///
+    /// # Returns
+    ///
+    /// A 256x256 array of `i8`, where each element represents the substitution cost
+    /// between two characters (byte values). Matching characters are assigned a cost
+    /// of 0, and non-matching characters are assigned a cost of -1.
     pub fn unary_substitution_costs() -> [[i8; 256]; 256] {
         let mut result = [[0; 256]; 256];
@@ -639,10 +650,36 @@ pub mod sz {
         result
     }
+    /// Randomizes the contents of a given byte slice `text` using characters from
+    /// a specified `alphabet`. This function mutates `text` in place, replacing each
+    /// byte with a random one from `alphabet`. It is designed for situations where
+    /// you need to generate random strings or data sequences based on a specific set
+    /// of characters, such as generating random DNA sequences or testing inputs.
+    ///
+    /// # Type Parameters
+    ///
+    /// * `T`: The type of the text to be randomized. Must be mutable and convertible to a byte slice.
+    /// * `A`: The type of the alphabet. Must be convertible to a byte slice.
+    ///
+    /// # Arguments
+    ///
+    /// * `text`: A mutable reference to the data to randomize. This data will be mutated in place.
+    /// * `alphabet`: A reference to the byte slice representing the alphabet to use for randomization.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use stringzilla::sz;
+    /// let mut my_text = vec![0; 10]; // A buffer to randomize
+    /// let alphabet = b"ACTG"; // Using a DNA alphabet
+    /// sz::randomize(&mut my_text, &alphabet);
+    /// ```
+    ///
+    /// After that, `my_text` is filled with random 'A', 'C', 'T', or 'G' values.
     pub fn randomize<T, A>(text: &mut T, alphabet: &A)
     where
-        T: AsMut<[u8]> + ?Sized, // Relaxing Sized restriction for T
-        A: AsRef<[u8]> + ?Sized, // Relaxing Sized restriction for A
+        T: AsMut<[u8]> + ?Sized, // Allows for mutable references to dynamically sized types.
+        A: AsRef<[u8]> + ?Sized, // Allows for references to dynamically sized types.
     {
         let text_slice = text.as_mut();
         let alphabet_slice = alphabet.as_ref();