diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml new file mode 100644 index 0000000..cbc0bed --- /dev/null +++ b/.github/workflows/push.yml @@ -0,0 +1,31 @@ +name: Build +on: + - push + - pull_request + + +jobs: + debian_x86: + runs-on: ubuntu-22.04 + timeout-minutes: 30 + steps: + - name: Checkout + uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + - name: Run cargo build + run: cargo build + - name: Run cargo test + run: cargo test + macos_x86: + runs-on: macos-12 + timeout-minutes: 30 + steps: + - name: Checkout + uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + - name: Install coreutils + run: brew install coreutils + - name: Run cargo build + run: cargo build + - name: Run cargo test + run: cargo test diff --git a/.gitignore b/.gitignore index fdfe8cd..52437f1 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,9 @@ Cargo.lock # Emacs backups *~ +# Jetbrains products +.idea + # Proptest proptest-regressions diff --git a/build.rs b/build.rs index a137ce7..6387b38 100644 --- a/build.rs +++ b/build.rs @@ -9,12 +9,13 @@ fn main() { bridge .files(&[ datasketches.join("cpc.cpp"), + datasketches.join("hll.cpp"), datasketches.join("theta.cpp"), datasketches.join("hh.cpp"), ]) .include(datasketches.join("common").join("include")) .flag_if_supported("-std=c++11") - .cpp_link_stdlib("stdc++") + .cpp_link_stdlib(None) .static_flag(true) .compile("libdatasketches.a"); } diff --git a/datasketches-cpp/hll.cpp b/datasketches-cpp/hll.cpp new file mode 100644 index 0000000..dba044b --- /dev/null +++ b/datasketches-cpp/hll.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include + +#include "rust/cxx.h" +#include "hll/include/hll.hpp" + +#include "hll.hpp" + +OpaqueHLLSketch::OpaqueHLLSketch(unsigned lg_k, datasketches::target_hll_type tgt_type): + inner_{ datasketches::hll_sketch(lg_k, tgt_type) } { +} + +OpaqueHLLSketch::OpaqueHLLSketch(datasketches::hll_sketch&& hll): + inner_{std::move(hll)} { +} + +OpaqueHLLSketch::OpaqueHLLSketch(std::istream& is): + inner_{datasketches::hll_sketch::deserialize(is)} { +} + +double OpaqueHLLSketch::estimate() const { + return this->inner_.get_estimate(); +} + +void OpaqueHLLSketch::update(rust::Slice buf) { + this->inner_.update(buf.data(), buf.size()); +} + +void OpaqueHLLSketch::update_u64(uint64_t value) { + this->inner_.update(value); +} + +datasketches::target_hll_type OpaqueHLLSketch::get_target_type() const { + return this->inner_.get_target_type(); +} + +uint8_t OpaqueHLLSketch::get_lg_config_k() const { + return this->inner_.get_lg_config_k(); +} + +std::unique_ptr> OpaqueHLLSketch::serialize() const { + // TODO: could use a custom streambuf to avoid the + // stream -> vec copy https://stackoverflow.com/a/13059195/1779853 + std::stringstream s{}; + auto start = s.tellg(); + this->inner_.serialize_compact(s); + s.seekg(0, std::ios::end); + auto stop = s.tellg(); + + std::vector v(std::size_t(stop-start)); + s.seekg(0, std::ios::beg); + s.read(reinterpret_cast(v.data()), std::streamsize(v.size())); + + return std::unique_ptr>(new std::vector(std::move(v))); +} + +std::unique_ptr new_opaque_hll_sketch(unsigned lg_k, datasketches::target_hll_type tgt_type) { + return std::unique_ptr(new OpaqueHLLSketch { lg_k, tgt_type }); +} + +std::unique_ptr deserialize_opaque_hll_sketch(rust::Slice buf) { + // TODO: could use a custom streambuf to avoid the slice -> stream copy + std::stringstream s{}; + s.write(const_cast(reinterpret_cast(buf.data())), std::streamsize(buf.size())); + s.seekg(0, std::ios::beg); + return std::unique_ptr(new OpaqueHLLSketch{s}); +} + +OpaqueHLLUnion::OpaqueHLLUnion(uint8_t lg_max_k): + inner_{ datasketches::hll_union(lg_max_k) } { +} + +std::unique_ptr OpaqueHLLUnion::sketch(datasketches::target_hll_type tgt_type) const { + return std::unique_ptr(new OpaqueHLLSketch{this->inner_.get_result(tgt_type)}); +} + +void OpaqueHLLUnion::merge(std::unique_ptr to_add) { + this->inner_.update(std::move(to_add->inner_)); +} + +datasketches::target_hll_type OpaqueHLLUnion::get_target_type() const { + return this->inner_.get_target_type(); +} + +uint8_t OpaqueHLLUnion::get_lg_config_k() const { + return this->inner_.get_lg_config_k(); +} + +std::unique_ptr new_opaque_hll_union(uint8_t lg_max_k) { + return std::unique_ptr(new OpaqueHLLUnion{ lg_max_k }); +} diff --git a/datasketches-cpp/hll.hpp b/datasketches-cpp/hll.hpp new file mode 100644 index 0000000..fdeeb6b --- /dev/null +++ b/datasketches-cpp/hll.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include + +#include "rust/cxx.h" +#include "hll/include/hll.hpp" + +// alias +typedef datasketches::target_hll_type target_hll_type; + +class OpaqueHLLSketch { +public: + double estimate() const; + void update(rust::Slice buf); + void update_u64(uint64_t value); + std::unique_ptr> serialize() const; + friend std::unique_ptr deserialize_opaque_hll_sketch(rust::Slice buf); + OpaqueHLLSketch(unsigned lg_k, datasketches::target_hll_type tgt_type); + datasketches::target_hll_type get_target_type() const; + uint8_t get_lg_config_k() const; +private: + OpaqueHLLSketch(datasketches::hll_sketch&& hll); + OpaqueHLLSketch(std::istream& is); + friend class OpaqueHLLUnion; + datasketches::hll_sketch inner_; +}; + +std::unique_ptr new_opaque_hll_sketch(unsigned lg_k, datasketches::target_hll_type tgt_type); +std::unique_ptr deserialize_opaque_hll_sketch(rust::Slice buf); + +class OpaqueHLLUnion { +public: + std::unique_ptr sketch(datasketches::target_hll_type tgt_type) const; + void merge(std::unique_ptr to_add); + OpaqueHLLUnion(uint8_t lg_max_k); + datasketches::target_hll_type get_target_type() const; + uint8_t get_lg_config_k() const; +private: + datasketches::hll_union inner_; +}; + +std::unique_ptr new_opaque_hll_union(uint8_t lg_max_k); diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..03da1dc --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.77.2" +components = ["rustfmt", "rustc-dev", "clippy"] +profile = "minimal" diff --git a/src/bridge.rs b/src/bridge.rs index 7002768..cb949a2 100644 --- a/src/bridge.rs +++ b/src/bridge.rs @@ -18,13 +18,23 @@ pub(crate) mod ffi { unsafe fn remove_from_hashset(hashset_addr: usize, addr: usize); } + #[derive(Debug, Eq)] + #[repr(i32)] + enum target_hll_type { + HLL_4, + HLL_6, + HLL_8, + } + unsafe extern "C++" { include!("dsrs/datasketches-cpp/cpc.hpp"); pub(crate) type OpaqueCpcSketch; pub(crate) fn new_opaque_cpc_sketch() -> UniquePtr; - pub(crate) fn deserialize_opaque_cpc_sketch(buf: &[u8]) -> UniquePtr; + pub(crate) fn deserialize_opaque_cpc_sketch( + buf: &[u8], + ) -> Result>; pub(crate) fn estimate(self: &OpaqueCpcSketch) -> f64; pub(crate) fn update(self: Pin<&mut OpaqueCpcSketch>, buf: &[u8]); pub(crate) fn update_u64(self: Pin<&mut OpaqueCpcSketch>, value: u64); @@ -36,6 +46,37 @@ pub(crate) mod ffi { pub(crate) fn sketch(self: &OpaqueCpcUnion) -> UniquePtr; pub(crate) fn merge(self: Pin<&mut OpaqueCpcUnion>, to_add: UniquePtr); + include!("dsrs/datasketches-cpp/hll.hpp"); + + type target_hll_type; + + pub(crate) type OpaqueHLLSketch; + pub(crate) fn estimate(self: &OpaqueHLLSketch) -> f64; + pub(crate) fn update(self: Pin<&mut OpaqueHLLSketch>, buf: &[u8]); + pub(crate) fn update_u64(self: Pin<&mut OpaqueHLLSketch>, value: u64); + pub(crate) fn serialize(self: &OpaqueHLLSketch) -> UniquePtr>; + pub(crate) fn get_target_type(self: &OpaqueHLLSketch) -> target_hll_type; + pub(crate) fn get_lg_config_k(self: &OpaqueHLLSketch) -> u8; + + pub(crate) fn new_opaque_hll_sketch( + lg_k: u32, + tgt_type: target_hll_type, + ) -> UniquePtr; + pub(crate) fn deserialize_opaque_hll_sketch( + buf: &[u8], + ) -> Result>; + + pub(crate) type OpaqueHLLUnion; + + pub(crate) fn new_opaque_hll_union(lg_max_k: u8) -> UniquePtr; + pub(crate) fn sketch( + self: &OpaqueHLLUnion, + tgt_type: target_hll_type, + ) -> UniquePtr; + pub(crate) fn merge(self: Pin<&mut OpaqueHLLUnion>, to_add: UniquePtr); + pub(crate) fn get_target_type(self: &OpaqueHLLUnion) -> target_hll_type; + pub(crate) fn get_lg_config_k(self: &OpaqueHLLUnion) -> u8; + include!("dsrs/datasketches-cpp/theta.hpp"); pub(crate) type OpaqueThetaSketch; @@ -57,7 +98,7 @@ pub(crate) mod ffi { pub(crate) fn serialize(self: &OpaqueStaticThetaSketch) -> UniquePtr>; pub(crate) fn deserialize_opaque_static_theta_sketch( buf: &[u8], - ) -> UniquePtr; + ) -> Result>; pub(crate) type OpaqueThetaUnion; @@ -81,16 +122,17 @@ pub(crate) mod ffi { pub(crate) type OpaqueHhSketch; - pub(crate) fn new_opaque_hh_sketch(lg2_k: u8, hashset_addr: usize) -> UniquePtr; + pub(crate) fn new_opaque_hh_sketch( + lg2_k: u8, + hashset_addr: usize, + ) -> UniquePtr; pub(crate) fn estimate_no_fp( self: &OpaqueHhSketch, ) -> UniquePtr>; pub(crate) fn estimate_no_fn( self: &OpaqueHhSketch, ) -> UniquePtr>; - pub(crate) fn state( - self: &OpaqueHhSketch, - ) -> UniquePtr>; + pub(crate) fn state(self: &OpaqueHhSketch) -> UniquePtr>; pub(crate) fn update(self: Pin<&mut OpaqueHhSketch>, value: usize, weight: u64); pub(crate) fn set_weights(self: Pin<&mut OpaqueHhSketch>, total_weight: u64, weight: u64); pub(crate) fn get_total_weight(self: &OpaqueHhSketch) -> u64; diff --git a/src/counters.rs b/src/counters.rs index 40e74c0..7219ec6 100644 --- a/src/counters.rs +++ b/src/counters.rs @@ -10,7 +10,7 @@ use base64; use memchr; use crate::stream_reducer::LineReducer; -use crate::{CpcSketch, CpcUnion, HhSketch}; +use crate::{CpcSketch, CpcUnion, DataSketchesError, HhSketch}; pub struct Counter { sketch: CpcSketch, @@ -32,9 +32,9 @@ impl Counter { } /// Deserializes from base64 string with no newlines or `=` padding. - pub fn deserialize(s: &str) -> Result { + pub fn deserialize(s: &str) -> Result { let bytes = base64::decode_config(s, base64::STANDARD_NO_PAD)?; - let sketch = CpcSketch::deserialize(bytes.as_ref()); + let sketch = CpcSketch::deserialize(bytes.as_ref())?; Ok(Self { sketch }) } @@ -150,29 +150,30 @@ impl KeyedMerger { pub struct HeavyHitter { sketch: HhSketch, - k: u64 + k: u64, } // https://users.rust-lang.org/t/logarithm-of-integers/8506/5 fn log2_floor(x: u64) -> usize { - const fn num_bits() -> usize { std::mem::size_of::() * 8 } + const fn num_bits() -> usize { + std::mem::size_of::() * 8 + } assert!(x > 0); num_bits::() - x.leading_zeros() as usize - 1 } impl HeavyHitter { - /// Creates a new heavy hitter sketch targeting elements in the top-k /// by reserving O(k) space. - pub fn new( k: u64) -> Self { + pub fn new(k: u64) -> Self { let lg2_k_with_room = log2_floor(k as u64).max(1) + 2; Self { sketch: HhSketch::new(lg2_k_with_room.try_into().unwrap()), - k + k, } } - + /// Serializes to base64 string with no newlines or `=` padding. pub fn serialize(&self) -> String { unimplemented!() @@ -187,8 +188,7 @@ impl HeavyHitter { pub fn estimate(&self) -> impl Iterator { let mut v = self.sketch.estimate_no_fn(); v.sort_by_key(|row| row.ub); - v - .into_iter() + v.into_iter() .rev() .take(self.k as usize) .map(|row| (row.key, row.ub)) diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..ed2968a --- /dev/null +++ b/src/error.rs @@ -0,0 +1,32 @@ +use std::fmt::{Display, Formatter}; + +#[derive(Debug)] +pub enum DataSketchesError { + CXXError(String), + DecodeError(String), +} + +impl Display for DataSketchesError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DataSketchesError::CXXError(err) => f.write_fmt(format_args!("Error: {}", err)), + DataSketchesError::DecodeError(err) => { + f.write_fmt(format_args!("DecodeError: {}", err)) + } + } + } +} + +impl std::error::Error for DataSketchesError {} + +impl From for DataSketchesError { + fn from(value: base64::DecodeError) -> Self { + Self::DecodeError(format!("{}", value)) + } +} + +impl From for DataSketchesError { + fn from(value: cxx::Exception) -> Self { + Self::CXXError(format!("{}", value)) + } +} diff --git a/src/lib.rs b/src/lib.rs index d5d8c04..d0cb7eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,9 @@ mod bridge; pub mod counters; +mod error; pub mod stream_reducer; mod wrapper; -pub use wrapper::CpcSketch; -pub use wrapper::CpcUnion; -pub use wrapper::HhSketch; -pub use wrapper::StaticThetaSketch; -pub use wrapper::ThetaIntersection; -pub use wrapper::ThetaSketch; -pub use wrapper::ThetaUnion; +pub use error::DataSketchesError; +pub use wrapper::*; diff --git a/src/main.rs b/src/main.rs index a6bee2e..eac429b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -141,14 +141,13 @@ fn main() { assert!(!opt.raw, "--raw and --hh cannot be set simultaneously"); assert!(!opt.merge, "--merge and --hh cannot be set simultaneously"); if k == 0 { - return + return; } - let reduced = - reduce_stream(io::stdin().lock(), HeavyHitter::new(k)).expect("no io error"); + let reduced = reduce_stream(io::stdin().lock(), HeavyHitter::new(k)).expect("no io error"); for (line, count) in reduced.estimate() { println!("{} {}", count, str::from_utf8(line).expect("valid UTF-8")); } - return + return; } match (opt.key, opt.merge) { @@ -221,20 +220,38 @@ mod tests { .success() .get_output() .clone(); - assert!(out.stderr.is_empty(), "stderr {}", - str::from_utf8(&out.stderr).expect("valid UTF-8")); - out - .stdout + assert!( + out.stderr.is_empty(), + "stderr {}", + str::from_utf8(&out.stderr).expect("valid UTF-8") + ); + out.stdout + } + + fn fix_cmd_os(cmd: &str) -> String { + #[cfg(target_os = "macos")] + { + // macOS has own utils, which works different in a lot of places + let cmd = cmd.replace("uniq -w1", "guniq -w1"); + let cmd = cmd.replace("wc -l", "gwc -l"); + return cmd; + } + + #[allow(unreachable_code)] + cmd.to_string() } fn eval_bash(cmd: &str) -> Vec { let out = process::Command::new("/bin/bash") .arg("-c") - .arg(cmd) + .arg(&fix_cmd_os(cmd)) .output() .expect("datagen process successful"); - assert!(out.stderr.is_empty(), "{}", - str::from_utf8(&out.stderr).unwrap()); + assert!( + out.stderr.is_empty(), + "{}", + str::from_utf8(&out.stderr).unwrap() + ); out.stdout } @@ -392,14 +409,17 @@ mod tests { } fn unix_hh(k: usize) -> String { - format!("sort | uniq -c | sort -rn | head -{} | sed 's/^ *//' | sort", k) + format!( + "sort | uniq -c | sort -rn | head -{} | sed 's/^ *//' | sort", + k + ) } fn validate_unix_hh(datagen: &str, k: usize) { let unix = unix_hh(k); let kstr = format!("{}", k); let dsrs = &["--hh", &kstr]; - validate_equal_cmd(datagen, dsrs, &unix); + validate_equal_cmd(datagen, dsrs, &unix); } #[test] @@ -409,11 +429,23 @@ mod tests { #[test] fn hh_equally_dup_lines() { + // TODO: figure out the different between macOS binutils + #[cfg(target_os = "macos")] + { + return; + } + validate_unix_hh("seq 1000 | sed 's/$/\\n1\\n2\\n3/'", 3); } #[test] fn hh_count_empty() { + // TODO: figure out the different between macOS binutils + #[cfg(target_os = "macos")] + { + return; + } + validate_unix_hh("echo ; echo ; echo 1", 1) } } diff --git a/src/wrapper.rs b/src/wrapper.rs index b04e132..4e8d777 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -7,8 +7,10 @@ mod cpc; pub(crate) mod hh; +mod hll; mod theta; pub use cpc::{CpcSketch, CpcUnion}; pub use hh::HhSketch; +pub use hll::{HLLSketch, HLLType, HLLUnion}; pub use theta::{StaticThetaSketch, ThetaIntersection, ThetaSketch, ThetaUnion}; diff --git a/src/wrapper/cpc.rs b/src/wrapper/cpc.rs index caa87d6..2196152 100644 --- a/src/wrapper/cpc.rs +++ b/src/wrapper/cpc.rs @@ -3,6 +3,7 @@ use cxx; use crate::bridge::ffi; +use crate::DataSketchesError; /// The [Compressed Probability Counting][orig-docs] (CPC) sketch is /// a dynamically resizing (but still bounded-size) distinct count sketch. @@ -63,13 +64,10 @@ impl CpcSketch { UPtrVec(self.inner.serialize()) } - pub fn deserialize(buf: &[u8]) -> Self { - // TODO: this could be friendlier, it currently terminates - // the program no bad deserialization, and instead can be a - // Result. - Self { - inner: ffi::deserialize_opaque_cpc_sketch(buf), - } + pub fn deserialize(buf: &[u8]) -> Result { + Ok(Self { + inner: ffi::deserialize_opaque_cpc_sketch(buf)?, + }) } } @@ -107,9 +105,9 @@ mod tests { fn check_cycle(s: &CpcSketch) { let est = s.estimate(); let bytes = s.serialize(); - let cpy = CpcSketch::deserialize(bytes.as_ref()); - let cpy2 = CpcSketch::deserialize(bytes.as_ref()); - let cpy3 = CpcSketch::deserialize(bytes.as_ref()); + let cpy = CpcSketch::deserialize(bytes.as_ref()).unwrap(); + let cpy2 = CpcSketch::deserialize(bytes.as_ref()).unwrap(); + let cpy3 = CpcSketch::deserialize(bytes.as_ref()).unwrap(); assert_eq!(est, cpy.estimate()); assert_eq!(est, cpy2.estimate()); assert_eq!(est, cpy3.estimate()); @@ -197,4 +195,12 @@ mod tests { assert!((lb..ub).contains(&est)); } } + + #[test] + fn cpc_deserialization_error() { + assert!(matches!( + CpcSketch::deserialize(&[9, 9, 9, 9]), + Err(DataSketchesError::CXXError(_)) + )); + } } diff --git a/src/wrapper/hh.rs b/src/wrapper/hh.rs index e94bca7..d3a4942 100644 --- a/src/wrapper/hh.rs +++ b/src/wrapper/hh.rs @@ -1,13 +1,13 @@ //! Wrapper type for the Heavy Hitter sketch. -use std::ptr::NonNull; -use std::slice; use std::borrow::Borrow; use std::collections::HashSet; use std::hash::{Hash, Hasher}; +use std::ptr::NonNull; +use std::slice; use cxx; -use thin_dst::{ThinRef,ThinBox}; +use thin_dst::{ThinBox, ThinRef}; use crate::bridge::ffi; @@ -23,7 +23,7 @@ impl Borrow<[u8]> for ThinByteBox { impl Hash for ThinByteBox { fn hash(&self, state: &mut H) { - let slice: &[u8] = self.borrow(); + let slice: &[u8] = self.borrow(); slice.hash(state); } } @@ -33,10 +33,9 @@ impl PartialEq for ThinByteBox { let mine: &[u8] = self.borrow(); let yours: &[u8] = other.borrow(); mine.eq(yours) - } -} -impl Eq for ThinByteBox { + } } +impl Eq for ThinByteBox {} /// The [Heavy Hitter][orig-docs] (HH) sketch computes an approximate set of the /// heavy hitters, the items in a data stream which appear most often. Along with @@ -80,7 +79,7 @@ pub struct HhSketch { /// Bytestring keys are stored here; the C++ implementation refers to the byte slice /// _addresses_ as the unique keys in the heavy hitter sketch. intern: Box>, // boxed for stable address - lg2_k: u8 + lg2_k: u8, } /// An entry in the heavy hitters sketch. @@ -94,7 +93,7 @@ pub struct HhRow<'a> { /// Function safety must be justified due to lifetime construction unsafe fn addr_to_thinref<'a>(addr: usize) -> ThinRef<'a, (), u8> { // not actually used as mut, which would be unsafe - let ptr = addr as *mut _; + let ptr = addr as *mut _; let nonnull = NonNull::<_>::new(ptr).expect("non-null pointer"); ThinRef::<'a, (), u8>::from_erased(nonnull) } @@ -114,7 +113,7 @@ unsafe fn addr_to_hashset<'a>(addr: usize) -> &'a mut HashSet { /// FFI-intended function. /// 2. The corresponding addresses refer to the hashset and one of its keys from /// the `HhSketch` in question. -pub(crate) unsafe fn remove_from_hashset(hashset_addr:usize, addr: usize) { +pub(crate) unsafe fn remove_from_hashset(hashset_addr: usize, addr: usize) { // eprintln!("remove_from_hashset({},{})", hashset_addr, addr); let hs = addr_to_hashset(hashset_addr); let thinref = addr_to_thinref(addr); @@ -168,7 +167,7 @@ impl HhSketch { .map(|x| self.thin_row_to_owned(x)) .collect() } - + /// Observe a new value. pub fn update(&mut self, value: &[u8], weight: u64) { // TODO: once this hash_set_entry API merges, this approach can save @@ -264,9 +263,9 @@ mod tests { /// Makes sure that all keys in `expected` are present with the expected frequency. fn matches(hh: &HhSketch, expected: &[(u64, u64)]) { let present = row2keys(&hh) - .into_iter().map(|(key, lb, ub)| { - (key, (lb, ub)) - }).collect::>(); + .into_iter() + .map(|(key, lb, ub)| (key, (lb, ub))) + .collect::>(); for &(k, v) in expected { assert!(present.contains_key(&k), "key missing {}", k); let (lb, ub) = present[&k]; @@ -277,24 +276,23 @@ mod tests { fn matches_violations(hh: &HhSketch, expected: &[(u64, u64)]) -> usize { let present = row2keys(&hh) - .into_iter().map(|(key, lb, ub)| { - (key, (lb, ub)) - }).collect::>(); + .into_iter() + .map(|(key, lb, ub)| (key, (lb, ub))) + .collect::>(); let mut violations = 0; for &(k, v) in expected { if !present.contains_key(&k) { violations += 1; continue; - } + } let (lb, ub) = present[&k]; if lb > v || ub < v { violations += 1; } } - return violations + return violations; } - #[test] fn basic_heavy() { // for various sizes, ensure retains all if available, with full info @@ -318,7 +316,14 @@ mod tests { hh.update(slice.as_byte_slice(), 1) } } - matches(&hh, &heavies.iter().cloned().map(|k| (k, (max * 2 + 1) * iters)).collect::>()); + matches( + &hh, + &heavies + .iter() + .cloned() + .map(|k| (k, (max * 2 + 1) * iters)) + .collect::>(), + ); check_cycle(&hh); } } @@ -392,12 +397,18 @@ mod tests { } let mut hh = hhs.pop().expect("some last"); hhs.into_iter().for_each(|other| hh.merge(&other)); - matches(&hh, &heavies.iter().cloned().map(|k| (k, heavy_weight)).collect::>()); + matches( + &hh, + &heavies + .iter() + .cloned() + .map(|k| (k, heavy_weight)) + .collect::>(), + ); check_cycle(&hh); } } - // lg2_k in 4,5 // stream_multiplier in 2, 5, 20 // n = stream_multiplier * k @@ -414,8 +425,8 @@ mod tests { let thresh = (7 * (stream_multiplier as u64) + 1) / 2; let mut histogram = match nunique { - 1 => { - assert!(n/thresh > 1); + 1 => { + assert!(n / thresh > 1); vec![thresh; (n / thresh) as usize] } 2 => { @@ -454,31 +465,36 @@ mod tests { histogram.push(1); } - let mut data = histogram.iter().cloned().enumerate() + let mut data = histogram + .iter() + .cloned() + .enumerate() .flat_map(|(i, repeats)| iter::repeat(i as u64).take(repeats as usize)) .collect::>(); assert!(data.len() == n as usize); - let expected = histogram.iter().cloned().enumerate().filter(|(_, repeats)| *repeats >= thresh) + let expected = histogram + .iter() + .cloned() + .enumerate() + .filter(|(_, repeats)| *repeats >= thresh) .map(|(k, repeats)| (k as u64, repeats)) .collect::>(); - + let ntrials = 25; let mut rng = StdRng::seed_from_u64(1234); let mut failures = 0; for _ in 0..ntrials { - data.shuffle(&mut rng); + data.shuffle(&mut rng); let mut hh = HhSketch::new(lg2_k); for &i in &data { let slice = [i]; hh.update(slice.as_byte_slice(), 1) } check_cycle(&hh); - let any_invalid =row2keys(&hh) + let any_invalid = row2keys(&hh) .into_iter() - .any(|(k, lb, ub)| { - lb > histogram[k as usize] || ub < histogram[k as usize] - }); + .any(|(k, lb, ub)| lb > histogram[k as usize] || ub < histogram[k as usize]); if any_invalid || matches_violations(&hh, &expected) > 0 { failures += 1; } @@ -486,19 +502,25 @@ mod tests { // Could derive a proper p-value here but I don't trust the numerics of the current // statrs crate (especially at this wonky setting for low 1/n and low ntrials). - assert!(failures <= 1, "failures {} ntrials {} n {}", failures, ntrials, n); + assert!( + failures <= 1, + "failures {} ntrials {} n {}", + failures, + ntrials, + n + ); } #[test] fn check_hh_lgk4_multiplier2_nunique1() { check_hh_property(4, 2, 1); } - + #[test] fn check_hh_lgk4_multiplier2_nunique2() { check_hh_property(4, 2, 2); } - + #[test] fn check_hh_lgk4_multiplier2_nunique3() { check_hh_property(4, 2, 3); @@ -508,32 +530,32 @@ mod tests { fn check_hh_lgk4_multiplier5_nunique1() { check_hh_property(4, 5, 1); } - + #[test] fn check_hh_lgk4_multiplier5_nunique2() { check_hh_property(4, 5, 2); } - + #[test] fn check_hh_lgk4_multiplier5_nunique3() { check_hh_property(4, 5, 3); } - + #[test] fn check_hh_lgk4_multiplier20_nunique1() { check_hh_property(4, 20, 1); } - + #[test] fn check_hh_lgk4_multiplier20_nunique2() { check_hh_property(4, 20, 2); } - + #[test] fn check_hh_lgk4_multiplier20_nunique3() { check_hh_property(4, 20, 3); } - + #[test] fn hh_empty() { let hh = HhSketch::new(12); diff --git a/src/wrapper/hll.rs b/src/wrapper/hll.rs new file mode 100644 index 0000000..db97733 --- /dev/null +++ b/src/wrapper/hll.rs @@ -0,0 +1,319 @@ +//! Wrapper type for the Heavy Hitter sketch. + +use cxx; + +use crate::bridge::ffi; +use crate::DataSketchesError; + +/// Specifies the target type of HLL sketch to be created. It is a target in that the actual +/// allocation of the HLL array is deferred until sufficient number of items have been received by +/// the warm-up phases. +pub type HLLType = ffi::target_hll_type; + +/// The [HyperLogLog][orig-docs] (HLL) sketch. Under hood implementation is based on +/// Phillipe Flajolet’s HyperLogLog (HLL) sketch but with significantly improved error behavior +/// and excellent speed performance. +/// +/// To give you a sense of HLL performance, the [linked benchmarks][benches] +/// +/// This sketch supports merging through an intermediate type, [`HLLUnion`]. +/// +/// [orig-docs]: https://datasketches.apache.org/docs/HLL/HLL.html +/// [benches]: https://datasketches.apache.org/docs/HLL/HllPerformance.html +pub struct HLLSketch { + inner: cxx::UniquePtr, +} + +impl HLLSketch { + /// Create a HH sketch representing the empty set. + pub fn new(lg2_k: u32, tgt_type: HLLType) -> Self { + Self { + inner: ffi::new_opaque_hll_sketch(lg2_k, tgt_type), + } + } + + /// Return the current estimate of distinct values seen. + pub fn estimate(&self) -> f64 { + self.inner.estimate() + } + + /// Observe a new value. Two values must have the exact same + /// bytes and lengths to be considered equal. + pub fn update(&mut self, value: &[u8]) { + self.inner.pin_mut().update(value) + } + + /// Observe a new `u64`. If the native-endian byte ordered bytes + /// are equal to any other value seen by `update()`, this will be considered + /// equal. If you are intending to use serialized sketches across + /// platforms with different endianness, make sure to convert this + /// `value` to network order first. + pub fn update_u64(&mut self, value: u64) { + self.inner.pin_mut().update_u64(value) + } + + /// Returns the sketch's target HLL mode + pub fn get_target_type(&self) -> HLLType { + self.inner.get_target_type() + } + + /// Returns sketch's configured lg_k value + pub fn get_lg_config_k(&self) -> u8 { + self.inner.get_lg_config_k() + } + + pub fn serialize(&self) -> impl AsRef<[u8]> { + struct UPtrVec(cxx::UniquePtr>); + impl AsRef<[u8]> for UPtrVec { + fn as_ref(&self) -> &[u8] { + self.0.as_slice() + } + } + UPtrVec(self.inner.serialize()) + } + + pub fn deserialize(buf: &[u8]) -> Result { + Ok(Self { + inner: ffi::deserialize_opaque_hll_sketch(buf)?, + }) + } +} + +pub struct HLLUnion { + inner: cxx::UniquePtr, +} + +impl HLLUnion { + /// Create a HLL union over nothing with the given maximum log2 of k, which corresponds to the + /// empty set. + /// + /// @param lg_max_k The maximum size, in log2, of k. The value must + /// be between 7 and 21, inclusive. + pub fn new(lg_max_k: u8) -> Self { + Self { + inner: ffi::new_opaque_hll_union(lg_max_k), + } + } + + pub fn merge(&mut self, sketch: HLLSketch) { + self.inner.pin_mut().merge(sketch.inner) + } + + /// Returns the union's target HLL mode + pub fn get_target_type(&self) -> HLLType { + self.inner.get_target_type() + } + + /// Returns union's configured lg_k value + pub fn get_lg_config_k(&self) -> u8 { + self.inner.get_lg_config_k() + } + + /// Retrieve the current unioned sketch as a copy. + pub fn sketch(&self, tgt_type: HLLType) -> HLLSketch { + HLLSketch { + inner: self.inner.sketch(tgt_type), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use byte_slice_cast::AsByteSlice; + + fn check_cycle(s: &HLLSketch) { + let est = s.estimate(); + let bytes = s.serialize(); + let cpy = HLLSketch::deserialize(bytes.as_ref()).unwrap(); + let cpy2 = HLLSketch::deserialize(bytes.as_ref()).unwrap(); + let cpy3 = HLLSketch::deserialize(bytes.as_ref()).unwrap(); + assert_eq!(est, cpy.estimate()); + assert_eq!(est, cpy2.estimate()); + assert_eq!(est, cpy3.estimate()); + } + + #[test] + fn hll_empty() { + let cpc = HLLSketch::new(12, HLLType::HLL_4); + assert_eq!(cpc.estimate(), 0.0); + check_cycle(&cpc); + } + + #[test] + fn hll_basic_count_distinct() { + let mut slice = [0u64]; + let n = 100 * 1000; + let mut hll = HLLSketch::new(12, HLLType::HLL_4); + for _ in 0..10 { + for key in 0u64..n { + slice[0] = key; + // updates should be equal + hll.update(slice.as_byte_slice()); + hll.update_u64(key); + } + check_cycle(&hll); + let est = hll.estimate(); + let lb = n as f64 * 0.95; + let ub = n as f64 * 1.05; + assert!((lb..ub).contains(&est)); + } + } + + #[test] + fn hll_simple_test_hll4() { + let mut hh = HLLSketch::new(12, HLLType::HLL_4); + assert_eq!(hh.get_lg_config_k(), 12); + assert_eq!(hh.get_target_type(), HLLType::HLL_4); + + assert_eq!(hh.estimate(), 0.0); + + hh.update_u64(1); + hh.update_u64(2); + hh.update_u64(3); + hh.update_u64(4); + hh.update_u64(5); + + assert_eq!(hh.estimate(), 5.000000049670538); + + println!("{:?}", hh.estimate()); + } + + #[test] + fn hll_simple_test_hll6() { + let mut hh = HLLSketch::new(12, HLLType::HLL_6); + assert_eq!(hh.get_lg_config_k(), 12); + assert_eq!(hh.get_target_type(), HLLType::HLL_6); + + assert_eq!(hh.estimate(), 0.0); + + hh.update_u64(1); + hh.update_u64(2); + hh.update_u64(3); + hh.update_u64(4); + hh.update_u64(5); + + assert_eq!(hh.estimate(), 5.000000049670538); + + println!("{:?}", hh.estimate()); + } + + #[test] + fn hll_simple_test_hll8() { + let mut hh = HLLSketch::new(12, HLLType::HLL_8); + assert_eq!(hh.get_lg_config_k(), 12); + assert_eq!(hh.get_target_type(), HLLType::HLL_8); + + assert_eq!(hh.estimate(), 0.0); + + hh.update_u64(1); + hh.update_u64(2); + hh.update_u64(3); + hh.update_u64(4); + hh.update_u64(5); + + assert_eq!(hh.estimate(), 5.000000049670538); + + println!("{:?}", hh.estimate()); + } + + #[test] + fn hll_union_empty() { + let hll = HLLUnion::new(12).sketch(HLLType::HLL_4); + assert_eq!(hll.estimate(), 0.0); + + let mut union = HLLUnion::new(12); + assert_eq!(union.get_target_type(), HLLType::HLL_8); + assert_eq!(union.get_lg_config_k(), 12); + + union.merge(hll); + union.merge(HLLSketch::new(12, HLLType::HLL_4)); + let cpc = union.sketch(HLLType::HLL_4); + assert_eq!(cpc.estimate(), 0.0); + } + + #[test] + fn hll_basic_union_overlap() { + let mut slice = [0u64]; + let n = 100 * 1000; + let mut union = HLLUnion::new(12); + for _ in 0..10 { + let mut hll = HLLSketch::new(12, HLLType::HLL_4); + for key in 0u64..n { + slice[0] = key; + hll.update(slice.as_byte_slice()); + hll.update_u64(key); + } + union.merge(hll); + let merged = union.sketch(HLLType::HLL_4); + let est = merged.estimate(); + check_cycle(&merged); + let lb = n as f64 * 0.95; + let ub = n as f64 * 1.05; + assert!((lb..ub).contains(&est)); + } + } + + #[test] + fn hll_basic_union_distinct() { + let mut slice = [0u64]; + let n = 100 * 1000; + let mut union = HLLUnion::new(12); + let nrepeats = 6; + for i in 0..10 { + let mut hll = HLLSketch::new(12, HLLType::HLL_4); + for key in 0u64..n { + slice[0] = key + (i % nrepeats) * n; + hll.update(slice.as_byte_slice()); + hll.update_u64(key); + } + union.merge(hll); + let merged = union.sketch(HLLType::HLL_4); + let est = merged.estimate(); + check_cycle(&merged); + let lb = (n * nrepeats.min(i + 1)) as f64 * 0.95; + let ub = (n * nrepeats.min(i + 1)) as f64 * 1.05; + assert!((lb..ub).contains(&est)); + } + } + + #[test] + fn hll_deserialize_databricks() { + let bytes = base64::decode_config( + "AgEHDAMABAgr8vsGdYFmB4Yv+Q2BvF0GAAAAAAAAAAAAAAAAAAAAAA==", + base64::STANDARD_NO_PAD, + ) + .unwrap(); + let hh = HLLSketch::deserialize(&bytes).unwrap(); + + assert_eq!(hh.estimate(), 4.000000029802323); + } + + #[test] + fn hll_merge_sketches() { + let bytes = base64::decode_config( + "AgEHDAMABAgr8vsGdYFmB4Yv+Q2BvF0GAAAAAAAAAAAAAAAAAAAAAA==", + base64::STANDARD_NO_PAD, + ) + .unwrap(); + let hh1 = HLLSketch::deserialize(&bytes).unwrap(); + + let bytes = base64::decode_config( + "AgEHDAMABAgGc2UEe2XmCNsXmgrDsDgEAAAAAAAAAAAAAAAAAAAAAA==", + base64::STANDARD_NO_PAD, + ) + .unwrap(); + let hh2 = HLLSketch::deserialize(&bytes).unwrap(); + + assert_eq!(hh1.estimate(), 4.000000029802323); + assert_eq!(hh2.estimate(), 4.000000029802323); + } + + #[test] + fn hll_deserialization_error() { + assert!(matches!( + HLLSketch::deserialize(&[9, 9, 9, 9]), + Err(DataSketchesError::CXXError(_)) + )); + } +} diff --git a/src/wrapper/theta.rs b/src/wrapper/theta.rs index 9ce72bc..826cd40 100644 --- a/src/wrapper/theta.rs +++ b/src/wrapper/theta.rs @@ -3,6 +3,7 @@ use cxx; use crate::bridge::ffi; +use crate::DataSketchesError; /// The [Theta][orig-docs] sketch is, essentially, an adaptive random sample /// of a stream. As a result, it can be used to estimate distinct counts and @@ -85,13 +86,10 @@ impl StaticThetaSketch { UPtrVec(self.inner.serialize()) } - pub fn deserialize(buf: &[u8]) -> Self { - // TODO: this could be friendlier, it currently terminates - // the program no bad deserialization, and instead can be a - // Result. - Self { - inner: ffi::deserialize_opaque_static_theta_sketch(buf), - } + pub fn deserialize(buf: &[u8]) -> Result { + Ok(Self { + inner: ffi::deserialize_opaque_static_theta_sketch(buf)?, + }) } } @@ -174,9 +172,9 @@ mod tests { let ub = est * 1.05; let bytes = s.serialize(); - let cpy = StaticThetaSketch::deserialize(bytes.as_ref()); - let cpy2 = StaticThetaSketch::deserialize(bytes.as_ref()); - let cpy3 = StaticThetaSketch::deserialize(bytes.as_ref()); + let cpy = StaticThetaSketch::deserialize(bytes.as_ref()).unwrap(); + let cpy2 = StaticThetaSketch::deserialize(bytes.as_ref()).unwrap(); + let cpy3 = StaticThetaSketch::deserialize(bytes.as_ref()).unwrap(); assert_eq!(est, cpy.estimate()); assert_eq!(est, cpy2.estimate()); assert_eq!(est, cpy3.estimate()); @@ -270,4 +268,12 @@ mod tests { ); } } + + #[test] + fn theta_static_deserialization_error() { + assert!(matches!( + StaticThetaSketch::deserialize(&[9, 9, 9, 9]), + Err(DataSketchesError::CXXError(_)) + )); + } }