From 5997114e8a46ea6ac5fe11ad2809d44c9f25bcac Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Thu, 14 Dec 2023 14:21:20 -0500 Subject: [PATCH 1/4] Implement chunk iterators that drop the GIL --- python/rocksdict/rocksdict.pyi | 340 +++++++++++++++++++++++---------- src/iter.rs | 285 ++++++++++++++++++++++++--- src/rdict.rs | 104 +++++++++- src/util.rs | 19 ++ test/bench_rdict.py | 78 +++++++- test/test_rdict.py | 140 ++++++++++++++ 6 files changed, 833 insertions(+), 133 deletions(-) diff --git a/python/rocksdict/rocksdict.pyi b/python/rocksdict/rocksdict.pyi index f02a8be..9c3016f 100644 --- a/python/rocksdict/rocksdict.pyi +++ b/python/rocksdict/rocksdict.pyi @@ -1,37 +1,39 @@ -from typing import Any, Union, List, Iterator, Tuple, Dict, overload, Callable - -__all__ = ["Rdict", - "RdictIter", - "Options", - "WriteOptions", - "ReadOptions", - "DBPath", - "MemtableFactory", - "BlockBasedOptions", - "PlainTableFactoryOptions", - "CuckooTableOptions", - "UniversalCompactOptions", - "UniversalCompactionStopStyle", - "SliceTransform", - "DataBlockIndexType", - "BlockBasedIndexType", - "Cache", - "ChecksumType", - "DBCompactionStyle", - "DBCompressionType", - "DBRecoveryMode", - "Env", - "FifoCompactOptions", - "SstFileWriter", - "IngestExternalFileOptions", - "WriteBatch", - "ColumnFamily", - "AccessType", - "Snapshot", - "CompactOptions", - "BottommostLevelCompaction", - "KeyEncodingType", - "DbClosedError"] +from typing import Any, Union, List, Iterator, Optional, Tuple, Dict, Callable + +__all__ = [ + "Rdict", + "RdictIter", + "Options", + "WriteOptions", + "ReadOptions", + "DBPath", + "MemtableFactory", + "BlockBasedOptions", + "PlainTableFactoryOptions", + "CuckooTableOptions", + "UniversalCompactOptions", + "UniversalCompactionStopStyle", + "SliceTransform", + "DataBlockIndexType", + "BlockBasedIndexType", + "Cache", + "ChecksumType", + "DBCompactionStyle", + "DBCompressionType", + "DBRecoveryMode", + "Env", + "FifoCompactOptions", + "SstFileWriter", + 
"IngestExternalFileOptions", + "WriteBatch", + "ColumnFamily", + "AccessType", + "Snapshot", + "CompactOptions", + "BottommostLevelCompaction", + "KeyEncodingType", + "DbClosedError", +] class DataBlockIndexType: @staticmethod @@ -170,7 +172,12 @@ class MemtableFactory: class Options: def __init__(self, raw_mode: bool = False) -> None: ... @staticmethod - def load_latest(path: str, env: Env = Env(), ignore_unknown_options: bool = False, cache: Cache = Cache(8 * 1024 * 1024)) -> Tuple[Options, Dict[str, Options]]: ... + def load_latest( + path: str, + env: Env = Env(), + ignore_unknown_options: bool = False, + cache: Cache = Cache(8 * 1024 * 1024), + ) -> Tuple[Options, Dict[str, Options]]: ... def create_if_missing(self, create_if_missing: bool) -> None: ... def create_missing_column_families(self, create_missing_cfs: bool) -> None: ... def enable_statistics(self) -> None: ... @@ -178,7 +185,9 @@ class Options: def increase_parallelism(self, parallelism: int) -> None: ... def optimize_for_point_lookup(self, cache_size: int) -> None: ... def optimize_level_style_compaction(self, memtable_memory_budget: int) -> None: ... - def optimize_universal_style_compaction(self, memtable_memory_budget: int) -> None: ... + def optimize_universal_style_compaction( + self, memtable_memory_budget: int + ) -> None: ... def prepare_for_bulk_load(self) -> None: ... def set_advise_random_on_open(self, advise: bool) -> None: ... def set_allow_concurrent_memtable_write(self, allow: bool) -> None: ... @@ -191,8 +200,10 @@ class Options: def set_bytes_per_sync(self, nbytes: int) -> None: ... def set_compaction_readahead_size(self, compaction_readahead_size: int) -> None: ... def set_compaction_style(self, style: DBCompactionStyle) -> None: ... - def set_compression_options(self, w_bits: int, level: int, strategy: int, max_dict_bytes: int) -> None: ... - def set_compression_per_level(self,level_types: list) -> None: ... 
+ def set_compression_options( + self, w_bits: int, level: int, strategy: int, max_dict_bytes: int + ) -> None: ... + def set_compression_per_level(self, level_types: list) -> None: ... def set_compression_type(self, t: DBCompressionType) -> None: ... def set_cuckoo_table_factory(self, factory: CuckooTableOptions) -> None: ... def set_db_log_dir(self, path: str) -> None: ... @@ -213,14 +224,16 @@ class Options: def set_keep_log_file_num(self, nfiles: int) -> None: ... def set_level_compaction_dynamic_level_bytes(self, v: bool) -> None: ... def set_level_zero_file_num_compaction_trigger(self, n: int) -> None: ... - def set_level_zero_slowdown_writes_trigger(self, n_int) -> None: ... + def set_level_zero_slowdown_writes_trigger(self, n: int) -> None: ... def set_level_zero_stop_writes_trigger(self, n: int) -> None: ... def set_log_file_time_to_roll(self, secs: int) -> None: ... def set_manifest_preallocation_size(self, size: int) -> None: ... def set_max_background_jobs(self, jobs: int) -> None: ... def set_max_bytes_for_level_base(self, size: int) -> None: ... def set_max_bytes_for_level_multiplier(self, mul: float) -> None: ... - def set_max_bytes_for_level_multiplier_additional(self, level_values: list) -> None: ... + def set_max_bytes_for_level_multiplier_additional( + self, level_values: list + ) -> None: ... def set_max_compaction_bytes(self, nbytes: int) -> None: ... def set_max_file_opening_threads(self, nthreads: int) -> None: ... def set_max_log_file_size(self, size: int) -> None: ... @@ -244,7 +257,9 @@ class Options: def set_paranoid_checks(self, enabled: bool) -> None: ... def set_plain_table_factory(self, options: PlainTableFactoryOptions) -> None: ... def set_prefix_extractor(self, prefix_extractor: SliceTransform) -> None: ... - def set_ratelimiter(self, rate_bytes_per_sec: int, refill_period_us: int, fairness: int) -> None: ... + def set_ratelimiter( + self, rate_bytes_per_sec: int, refill_period_us: int, fairness: int + ) -> None: ... 
def set_recycle_log_file_num(self, num: int) -> None: ... def set_report_bg_io_stats(self, enable: bool) -> None: ... def set_row_cache(self, cache: Cache) -> None: ... @@ -256,7 +271,9 @@ class Options: def set_table_cache_num_shard_bits(self, nbits: int) -> None: ... def set_target_file_size_base(self, size: int) -> None: ... def set_target_file_size_multiplier(self, multiplier: int) -> None: ... - def set_universal_compaction_options(self, uco: UniversalCompactOptions) -> None: ... + def set_universal_compaction_options( + self, uco: UniversalCompactOptions + ) -> None: ... def set_unordered_write(self, unordered: bool) -> None: ... def set_use_adaptive_mutex(self, enabled: bool) -> None: ... def set_use_direct_io_for_flush_and_compaction(self, enabled: bool) -> None: ... @@ -295,8 +312,12 @@ class ReadOptions: def fill_cache(self) -> None: ... def set_background_purge_on_iterator_cleanup(self, v: bool) -> None: ... def set_ignore_range_deletions(self, v: bool) -> None: ... - def set_iterate_lower_bound(self, key: Union[str, int, float, bytes, bool]) -> None: ... - def set_iterate_upper_bound(self, key: Union[str, int, float, bytes, bool]) -> None: ... + def set_iterate_lower_bound( + self, key: Union[str, int, float, bytes, bool] + ) -> None: ... + def set_iterate_upper_bound( + self, key: Union[str, int, float, bytes, bool] + ) -> None: ... def set_max_skippable_internal_keys(self, num: int) -> None: ... def set_pin_data(self, v: bool) -> None: ... def set_prefix_same_as_start(self, v: bool) -> None: ... @@ -372,10 +393,13 @@ class WriteOptions: def disable_wal(self, disable: bool) -> None: ... class Rdict: - def __init__(self, path: str, - options: Union[Options, None] = None, - column_families: Union[Dict[str, Options], None] = None, - access_type: AccessType = AccessType.read_write()) -> None: ... 
+ def __init__( + self, + path: str, + options: Union[Options, None] = None, + column_families: Union[Dict[str, Options], None] = None, + access_type: AccessType = AccessType.read_write(), + ) -> None: ... def __enter__(self) -> Rdict: ... def set_dumps(self, dumps: Callable[[Any], bytes]) -> None: ... def set_loads(self, dumps: Callable[[bytes], Any]) -> None: ... @@ -383,41 +407,100 @@ class Rdict: def set_write_options(self, write_opt: WriteOptions) -> None: ... def __contains__(self, key: Union[str, int, float, bytes, bool]) -> bool: ... def __delitem__(self, key: Union[str, int, float, bytes, bool]) -> None: ... - def __getitem__(self, key: Union[str, int, float, bytes, bool, List[Union[str, int, float, bytes, bool]]]) -> Any | None: ... - def __setitem__(self, key: Union[str, int, float, bytes, bool], value: Any) -> None: ... - def get(self, - key: Union[str, int, float, bytes, bool, List[Union[str, int, float, bytes, bool]]], - default: Any = None, - read_opt: Union[ReadOptions, None] = None) -> Any | None: ... - def put(self, - key: Union[str, int, float, bytes, bool], - value: Any, - write_opt: Union[WriteOptions, None] = None) -> None: ... - def delete(self, key: Union[str, int, float, bytes, bool], write_opt: Union[WriteOptions, None] = None) -> None: ... - def key_may_exist(self, - key: Union[str, int, float, bytes, bool], - fetch: bool = False, - read_opt = None) -> Union[bool, Tuple[bool, Any]]: ... + def __getitem__( + self, + key: Union[ + str, int, float, bytes, bool, List[Union[str, int, float, bytes, bool]] + ], + ) -> Any | None: ... + def __setitem__( + self, key: Union[str, int, float, bytes, bool], value: Any + ) -> None: ... + def get( + self, + key: Union[ + str, int, float, bytes, bool, List[Union[str, int, float, bytes, bool]] + ], + default: Any = None, + read_opt: Union[ReadOptions, None] = None, + ) -> Any | None: ... 
+ def put( + self, + key: Union[str, int, float, bytes, bool], + value: Any, + write_opt: Union[WriteOptions, None] = None, + ) -> None: ... + def delete( + self, + key: Union[str, int, float, bytes, bool], + write_opt: Union[WriteOptions, None] = None, + ) -> None: ... + def key_may_exist( + self, + key: Union[str, int, float, bytes, bool], + fetch: bool = False, + read_opt=None, + ) -> Union[bool, Tuple[bool, Any]]: ... def iter(self, read_opt: Union[ReadOptions, None] = None) -> RdictIter: ... - def items(self, backwards: bool = False, - from_key: Union[str, int, float, bytes, bool, None] = None, - read_opt: Union[ReadOptions, None] = None) -> RdictItems: ... - def keys(self, backwards: bool = False, - from_key: Union[str, int, float, bytes, bool, None] = None, - read_opt: Union[ReadOptions, None] = None) -> RdictKeys: ... - def values(self, backwards: bool = False, - from_key: Union[str, int, float, bytes, bool, None] = None, - read_opt: Union[ReadOptions, None] = None) -> RdictValues: ... - def ingest_external_file(self, paths: List[str], opts: IngestExternalFileOptions = IngestExternalFileOptions()) -> None: ... + def items( + self, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictItems: ... + def chunked_items( + self, + chunk_size: Optional[int] = None, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictChunkedItems: ... + def keys( + self, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictKeys: ... + def chunked_keys( + self, + chunk_size: Optional[int] = None, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictChunkedKeys: ... 
+ def values( + self, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictValues: ... + def chunked_values( + self, + chunk_size: Optional[int] = None, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictChunkedValues: ... + def ingest_external_file( + self, + paths: List[str], + opts: IngestExternalFileOptions = IngestExternalFileOptions(), + ) -> None: ... def get_column_family(self, name: str) -> Rdict: ... def get_column_family_handle(self, name: str) -> ColumnFamily: ... def drop_column_family(self, name: str) -> None: ... - def create_column_family(self, name: str, options: Options = Options()) -> Rdict: ... - def write(self, write_batch: WriteBatch, write_opt: Union[WriteOptions, None] = None) -> None: ... - def delete_range(self, - begin: Union[str, int, float, bytes, bool], - end: Union[str, int, float, bytes, bool], - write_opt: Union[WriteOptions, None] = None) -> None: ... + def create_column_family( + self, name: str, options: Options = Options() + ) -> Rdict: ... + def write( + self, write_batch: WriteBatch, write_opt: Union[WriteOptions, None] = None + ) -> None: ... + def delete_range( + self, + begin: Union[str, int, float, bytes, bool], + end: Union[str, int, float, bytes, bool], + write_opt: Union[WriteOptions, None] = None, + ) -> None: ... def snapshot(self) -> Snapshot: ... def path(self) -> str: ... def set_options(self, options: Dict[str, str]) -> None: ... @@ -425,9 +508,12 @@ class Rdict: def property_int_value(self, name: str) -> Union[int, None]: ... def latest_sequence_number(self) -> int: ... def live_files(self) -> List[Dict[str, Any]]: ... - def compact_range(self, begin: Union[str, int, float, bytes, bool, None], - end: Union[str, int, float, bytes, bool, None], - compact_opt: CompactOptions = CompactOptions()) -> None: ... 
+ def compact_range( + self, + begin: Union[str, int, float, bytes, bool, None], + end: Union[str, int, float, bytes, bool, None], + compact_opt: CompactOptions = CompactOptions(), + ) -> None: ... def try_catch_up_with_primary(self) -> None: ... def cancel_all_background(self, wait: bool) -> None: ... def close(self) -> None: ... @@ -453,6 +539,20 @@ class RdictValues(Iterator[Any]): def __iter__(self) -> RdictValues: ... def __next__(self) -> Any: ... +class RdictChunkedItems( + Iterator[List[Tuple[Union[str, int, float, bytes, bool], Any]]] +): + def __iter__(self) -> RdictChunkedItems: ... + def __next__(self) -> List[Tuple[Union[str, int, float, bytes, bool], Any]]: ... + +class RdictChunkedKeys(Iterator[List[Union[str, int, float, bytes, bool]]]): + def __iter__(self) -> RdictChunkedKeys: ... + def __next__(self) -> List[Union[str, int, float, bytes, bool]]: ... + +class RdictChunkedValues(Iterator[List[Any]]): + def __iter__(self) -> RdictChunkedValues: ... + def __next__(self) -> List[Any]: ... + class RdictIter: def valid(self) -> bool: ... def status(self) -> None: ... @@ -464,6 +564,15 @@ class RdictIter: def prev(self) -> None: ... def key(self) -> Any: ... def value(self) -> Any: ... + def get_chunk_keys( + self, chunk_size: Optional[int] = None, backwards: bool = False + ) -> List[Union[str, int, float, bytes, bool]]: ... + def get_chunk_values( + self, chunk_size: Optional[int] = None, backwards: bool = False + ) -> List[Any]: ... + def get_chunk_items( + self, chunk_size: Optional[int] = None, backwards: bool = False + ) -> List[Tuple[Union[str, int, float, bytes, bool], Any]]: ... class IngestExternalFileOptions: def __init__(self) -> None: ... @@ -479,26 +588,42 @@ class SstFileWriter: def open(self, path: str) -> None: ... def finish(self) -> None: ... def file_size(self) -> int: ... - def __setitem__(self, key: Union[str, int, float, bytes, bool], value: Any) -> None: ... 
+ def __setitem__( + self, key: Union[str, int, float, bytes, bool], value: Any + ) -> None: ... def __delitem__(self, key: Union[str, int, float, bytes, bool]) -> None: ... class WriteBatch: def __init__(self, raw_mode: bool = False) -> None: ... def __len__(self) -> int: ... - def __setitem__(self, key: Union[str, int, float, bytes, bool], value: Any) -> None: ... + def __setitem__( + self, key: Union[str, int, float, bytes, bool], value: Any + ) -> None: ... def __delitem__(self, key: Union[str, int, float, bytes, bool]) -> None: ... def set_dumps(self, dumps: Callable[[Any], bytes]) -> None: ... - def set_default_column_family(self, column_family: Union[ColumnFamily, None]) -> None: ... + def set_default_column_family( + self, column_family: Union[ColumnFamily, None] + ) -> None: ... def len(self) -> int: ... def size_in_bytes(self) -> int: ... def is_empty(self) -> bool: ... - def put(self, key: Union[str, int, float, bytes, bool], value: Any, - column_family: Union[ColumnFamily, None] = None) -> None: ... - def delete(self, key: Union[str, int, float, bytes, bool], - column_family: Union[ColumnFamily, None] = None) -> None: ... - def delete_range(self, begin: Union[str, int, float, bytes, bool], - end: Union[str, int, float, bytes, bool], - column_family: Union[ColumnFamily, None] = None) -> None: ... + def put( + self, + key: Union[str, int, float, bytes, bool], + value: Any, + column_family: Union[ColumnFamily, None] = None, + ) -> None: ... + def delete( + self, + key: Union[str, int, float, bytes, bool], + column_family: Union[ColumnFamily, None] = None, + ) -> None: ... + def delete_range( + self, + begin: Union[str, int, float, bytes, bool], + end: Union[str, int, float, bytes, bool], + column_family: Union[ColumnFamily, None] = None, + ) -> None: ... def clear(self) -> None: ... class ColumnFamily: ... @@ -516,15 +641,24 @@ class AccessType: class Snapshot: def __getitem__(self, key: Union[str, int, float, bytes, bool]) -> Any: ... 
def iter(self, read_opt: Union[ReadOptions, None] = None) -> RdictIter: ... - def items(self, backwards: bool = False, - from_key: Union[str, int, float, bytes, bool, None] = None, - read_opt: Union[ReadOptions, None] = None) -> RdictItems: ... - def keys(self, backwards: bool = False, - from_key: Union[str, int, float, bytes, bool, None] = None, - read_opt: Union[ReadOptions, None] = None) -> RdictKeys: ... - def values(self, backwards: bool = False, - from_key: Union[str, int, float, bytes, bool, None] = None, - read_opt: Union[ReadOptions, None] = None) -> RdictValues: ... + def items( + self, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictItems: ... + def keys( + self, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictKeys: ... + def values( + self, + backwards: bool = False, + from_key: Union[str, int, float, bytes, bool, None] = None, + read_opt: Union[ReadOptions, None] = None, + ) -> RdictValues: ... class BottommostLevelCompaction: @staticmethod @@ -539,7 +673,9 @@ class BottommostLevelCompaction: class CompactOptions: def __init__(self) -> None: ... def set_exclusive_manual_compaction(self, v: bool) -> None: ... - def set_bottommost_level_compaction(self, lvl: BottommostLevelCompaction) -> None: ... + def set_bottommost_level_compaction( + self, lvl: BottommostLevelCompaction + ) -> None: ... def set_change_level(self, v: bool) -> None: ... def set_target_level(self, lvl: int) -> None: ... 
diff --git a/src/iter.rs b/src/iter.rs index 5db026a..d957596 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -1,7 +1,7 @@ use crate::db_reference::DbReferenceHolder; use crate::encoder::{decode_value, encode_key}; use crate::exceptions::DbClosedError; -use crate::util::error_message; +use crate::util::{error_message, SendSyncMutPtr}; use crate::{ReadOpt, ReadOptionsPy}; use core::slice; use libc::{c_char, c_uchar, size_t}; @@ -17,7 +17,10 @@ pub(crate) struct RdictIter { /// iterator must keep a reference count of DB to keep DB alive. pub(crate) db: DbReferenceHolder, - pub(crate) inner: *mut librocksdb_sys::rocksdb_iterator_t, + // This is a wrapper around a `*mut rocksdb_iterator_t`. It is wrapped in a `SendSyncMutPtr`, so + // it is the responsibility of any user that sends it across threads to ensure that the thread + // does not outlive this iterator. + pub(crate) inner: SendSyncMutPtr, /// When iterate_upper_bound is set, the inner C iterator keeps a pointer to the upper bound /// inside `_readopts`. Storing this makes sure the upper bound is always alive when the @@ -64,16 +67,22 @@ impl RdictIter { .ok_or_else(|| DbClosedError::new_err("DB instance already closed"))? .inner(); + let inner = unsafe { + match cf { + None => SendSyncMutPtr::new(librocksdb_sys::rocksdb_create_iterator( + db_inner, readopts.0, + )), + Some(cf) => SendSyncMutPtr::new(librocksdb_sys::rocksdb_create_iterator_cf( + db_inner, + readopts.0, + cf.inner(), + )), + } + }; + Ok(RdictIter { db: db.clone(), - inner: unsafe { - match cf { - None => librocksdb_sys::rocksdb_create_iterator(db_inner, readopts.0), - Some(cf) => { - librocksdb_sys::rocksdb_create_iterator_cf(db_inner, readopts.0, cf.inner()) - } - } - }, + inner, readopts, pickle_loads: pickle_loads.clone(), raw_mode, @@ -91,7 +100,7 @@ impl RdictIter { /// return an error when `valid` is `true`. 
#[inline] pub fn valid(&self) -> bool { - unsafe { librocksdb_sys::rocksdb_iter_valid(self.inner) != 0 } + unsafe { librocksdb_sys::rocksdb_iter_valid(self.inner.get()) != 0 } } /// Returns an error `Result` if the iterator has encountered an error @@ -102,7 +111,7 @@ impl RdictIter { pub fn status(&self) -> PyResult<()> { let mut err: *mut c_char = null_mut(); unsafe { - librocksdb_sys::rocksdb_iter_get_error(self.inner, &mut err); + librocksdb_sys::rocksdb_iter_get_error(self.inner.get(), &mut err); } if !err.is_null() { Err(PyException::new_err(error_message(err))) @@ -137,7 +146,7 @@ impl RdictIter { /// Rdict.destroy(path, Options()) pub fn seek_to_first(&mut self) { unsafe { - librocksdb_sys::rocksdb_iter_seek_to_first(self.inner); + librocksdb_sys::rocksdb_iter_seek_to_first(self.inner.get()); } } @@ -167,7 +176,7 @@ impl RdictIter { /// Rdict.destroy(path, Options()) pub fn seek_to_last(&mut self) { unsafe { - librocksdb_sys::rocksdb_iter_seek_to_last(self.inner); + librocksdb_sys::rocksdb_iter_seek_to_last(self.inner.get()); } } @@ -195,7 +204,7 @@ impl RdictIter { let key = encode_key(key, self.raw_mode)?; unsafe { librocksdb_sys::rocksdb_iter_seek( - self.inner, + self.inner.get(), key.as_ptr() as *const c_char, key.len() as size_t, ); @@ -228,7 +237,7 @@ impl RdictIter { let key = encode_key(key, self.raw_mode)?; unsafe { librocksdb_sys::rocksdb_iter_seek_for_prev( - self.inner, + self.inner.get(), key.as_ptr() as *const c_char, key.len() as size_t, ); @@ -239,14 +248,14 @@ impl RdictIter { /// Seeks to the next key. pub fn next(&mut self) { unsafe { - librocksdb_sys::rocksdb_iter_next(self.inner); + librocksdb_sys::rocksdb_iter_next(self.inner.get()); } } /// Seeks to the previous key. 
pub fn prev(&mut self) { unsafe { - librocksdb_sys::rocksdb_iter_prev(self.inner); + librocksdb_sys::rocksdb_iter_prev(self.inner.get()); } } @@ -258,8 +267,8 @@ impl RdictIter { unsafe { let mut key_len: size_t = 0; let key_len_ptr: *mut size_t = &mut key_len; - let key_ptr = - librocksdb_sys::rocksdb_iter_key(self.inner, key_len_ptr) as *const c_uchar; + let key_ptr = librocksdb_sys::rocksdb_iter_key(self.inner.get(), key_len_ptr) + as *const c_uchar; let key = slice::from_raw_parts(key_ptr, key_len); Ok(decode_value(py, key, &self.pickle_loads, self.raw_mode)?) } @@ -276,8 +285,8 @@ impl RdictIter { unsafe { let mut val_len: size_t = 0; let val_len_ptr: *mut size_t = &mut val_len; - let val_ptr = - librocksdb_sys::rocksdb_iter_value(self.inner, val_len_ptr) as *const c_uchar; + let val_ptr = librocksdb_sys::rocksdb_iter_value(self.inner.get(), val_len_ptr) + as *const c_uchar; let value = slice::from_raw_parts(val_ptr, val_len); Ok(decode_value(py, value, &self.pickle_loads, self.raw_mode)?) } @@ -285,12 +294,179 @@ impl RdictIter { Ok(py.None()) } } + + /// Returns a chunk of keys from the iterator. + /// + /// This is more efficient than calling the iterator per element and will drop the GIL while + /// fetching the chunk. + /// + /// Args: + /// chunk_size: the number of items to return. If `None`, items will be returned until the + /// iterator is exhausted. + /// backwards: if `True`, iterator will traverse backwards. + #[pyo3(signature = (chunk_size = None, backwards = false))] + pub fn get_chunk_keys( + &mut self, + chunk_size: Option, + backwards: bool, + py: Python, + ) -> PyResult> { + let raw_keys = py.allow_threads(|| { + let mut raw_keys = Vec::new(); + while self.valid() && raw_keys.len() < chunk_size.unwrap_or(usize::MAX) { + // Safety: This is safe for multiple reasons: + // * It makes a copy of the buffer before returning. + // * This `allow_threads` block does not outlive the iterator's lifetime. 
+ let key = unsafe { + let mut key_len: size_t = 0; + let key_len_ptr: *mut size_t = &mut key_len; + let key_ptr = librocksdb_sys::rocksdb_iter_key(self.inner.get(), key_len_ptr) + as *const c_uchar; + slice::from_raw_parts(key_ptr, key_len) + .to_vec() + .into_boxed_slice() + }; + raw_keys.push(key); + + if backwards { + self.prev(); + } else { + self.next(); + } + } + + raw_keys + }); + + raw_keys + .into_iter() + .map(|key| decode_value(py, &key, &self.pickle_loads, self.raw_mode)) + .collect() + } + + /// Returns a chunk of values from the iterator. + /// + /// This is more efficient than calling the iterator per element and will drop the GIL while + /// fetching the chunk. + /// + /// Args: + /// chunk_size: the number of items to return. If `None`, items will be returned until the + /// iterator is exhausted. + /// backwards: if `True`, iterator will traverse backwards. + #[pyo3(signature = (chunk_size = None, backwards = false))] + pub fn get_chunk_values( + &mut self, + chunk_size: Option, + backwards: bool, + py: Python, + ) -> PyResult> { + let raw_values = py.allow_threads(|| { + let mut raw_values = Vec::new(); + while self.valid() && raw_values.len() < chunk_size.unwrap_or(usize::MAX) { + // Safety: This is safe for multiple reasons: + // * It makes a copy of the buffer before returning. + // * This `allow_threads` block does not outlive the iterator's lifetime. + let value = unsafe { + let mut value_len: size_t = 0; + let value_len_ptr: *mut size_t = &mut value_len; + let value_ptr = + librocksdb_sys::rocksdb_iter_value(self.inner.get(), value_len_ptr) + as *const c_uchar; + slice::from_raw_parts(value_ptr, value_len) + .to_vec() + .into_boxed_slice() + }; + raw_values.push(value); + + if backwards { + self.prev(); + } else { + self.next(); + } + } + + raw_values + }); + + raw_values + .into_iter() + .map(|value| decode_value(py, &value, &self.pickle_loads, self.raw_mode)) + .collect() + } + + /// Returns a chunk of key-value pairs from the iterator. 
+ /// + /// This is more efficient than calling the iterator per element and will drop the GIL while + /// fetching the chunk. + /// + /// Args: + /// chunk_size: the number of items to return. If `None`, items will be returned until the + /// iterator is exhausted. + /// backwards: if `True`, iterator will traverse backwards. + #[pyo3(signature = (chunk_size = None, backwards = false))] + pub fn get_chunk_items( + &mut self, + chunk_size: Option, + backwards: bool, + py: Python, + ) -> PyResult> { + let raw_items = py.allow_threads(|| { + let mut raw_items = Vec::new(); + while self.valid() && raw_items.len() < chunk_size.unwrap_or(usize::MAX) { + // Safety: This is safe for multiple reasons: + // * It makes a copy of the buffer before returning. + // * This `allow_threads` block does not outlive the iterator's lifetime. + let key = unsafe { + let mut key_len: size_t = 0; + let key_len_ptr: *mut size_t = &mut key_len; + let key_ptr = librocksdb_sys::rocksdb_iter_key(self.inner.get(), key_len_ptr) + as *const c_uchar; + slice::from_raw_parts(key_ptr, key_len) + .to_vec() + .into_boxed_slice() + }; + + // Safety: This is safe for multiple reasons: + // * It makes a copy of the buffer before returning. + // * This `allow_threads` block does not outlive the iterator's lifetime. 
+ let value = unsafe { + let mut value_len: size_t = 0; + let value_len_ptr: *mut size_t = &mut value_len; + let value_ptr = + librocksdb_sys::rocksdb_iter_value(self.inner.get(), value_len_ptr) + as *const c_uchar; + slice::from_raw_parts(value_ptr, value_len) + .to_vec() + .into_boxed_slice() + }; + + raw_items.push((key, value)); + + if backwards { + self.prev(); + } else { + self.next(); + } + } + + raw_items + }); + + raw_items + .into_iter() + .map(|(key, value)| { + let key = decode_value(py, &key, &self.pickle_loads, self.raw_mode)?; + let value = decode_value(py, &value, &self.pickle_loads, self.raw_mode)?; + Ok((key, value)) + }) + .collect() + } } impl Drop for RdictIter { fn drop(&mut self) { unsafe { - librocksdb_sys::rocksdb_iter_destroy(self.inner); + librocksdb_sys::rocksdb_iter_destroy(self.inner.get()); } } } @@ -345,6 +521,69 @@ macro_rules! impl_iter { }; } +macro_rules! impl_chunked_iter { + ($iter_name: ident, $iter_chunk_fn: ident) => { + #[pyclass] + pub(crate) struct $iter_name { + inner: RdictIter, + backwards: bool, + chunk_size: Option<usize>, + } + + #[pymethods] + impl $iter_name { + fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { + slf + } + + fn __next__(&mut self, py: Python) -> PyResult<Option<PyObject>> { + if self.inner.valid() { + Ok(Some( + self.inner + .$iter_chunk_fn(self.chunk_size, self.backwards, py) + .map(|v| v.to_object(py))?, + )) + } else { + Ok(None) + } + } + } + + impl $iter_name { + pub(crate) fn new( + inner: RdictIter, + chunk_size: Option<usize>, + backwards: bool, + from_key: Option<&PyAny>, + ) -> PyResult<Self> { + let mut inner = inner; + if let Some(from_key) = from_key { + if backwards { + inner.seek_for_prev(from_key)?; + } else { + inner.seek(from_key)?; + } + } else { + if backwards { + inner.seek_to_last(); + } else { + inner.seek_to_first(); + } + } + Ok(Self { + inner, + backwards, + chunk_size, + }) + } + } + }; +} + impl_iter!(RdictKeys, key); impl_iter!(RdictValues, value); impl_iter!(RdictItems, key, value); + 
+impl_chunked_iter!(RdictChunkedKeys, get_chunk_keys); +impl_chunked_iter!(RdictChunkedValues, get_chunk_values); +impl_chunked_iter!(RdictChunkedItems, get_chunk_items); diff --git a/src/rdict.rs b/src/rdict.rs index 781b32b..9342f74 100644 --- a/src/rdict.rs +++ b/src/rdict.rs @@ -1,7 +1,9 @@ use crate::db_reference::{DbReference, DbReferenceHolder}; use crate::encoder::{decode_value, encode_key, encode_value}; use crate::exceptions::DbClosedError; -use crate::iter::{RdictItems, RdictKeys, RdictValues}; +use crate::iter::{ + RdictChunkedItems, RdictChunkedKeys, RdictChunkedValues, RdictItems, RdictKeys, RdictValues, +}; use crate::options::{CachePy, EnvPy, SliceTransformType}; use crate::{ CompactOptionsPy, FlushOptionsPy, IngestExternalFileOptionsPy, OptionsPy, RdictIter, @@ -624,7 +626,7 @@ impl Rdict { ) } - /// Iterate through all keys and values pairs. + /// Creates an iterator through the items. /// /// Examples: /// :: @@ -649,7 +651,39 @@ impl Rdict { RdictItems::new(self.iter(read_opt, py)?, backwards, from_key) } - /// Iterate through all keys + /// Creates a chunked iterator through the items. + /// + /// This is more efficient than a normal per-element iterator + /// and will drop the GIL while fetching a chunk. + /// + /// Examples: + /// :: + /// + /// for chunk in db.chunked_items(chunk_size=1000): + /// for k, v in chunk: + /// print(f"{k} -> {v}") + /// + /// Args: + /// chunk_size: the number of items to return. If None, + /// returns all items in one chunk. + /// backwards: iteration direction, forward if `False`. + /// from_key: iterate from key, first seek to this key + /// or the nearest next key for iteration + /// (depending on iteration direction). 
+ /// read_opt: ReadOptions + #[pyo3(signature = (chunk_size = None, backwards = false, from_key = None, read_opt = None))] + fn chunked_items( + &self, + chunk_size: Option, + backwards: bool, + from_key: Option<&PyAny>, + read_opt: Option<&ReadOptionsPy>, + py: Python, + ) -> PyResult { + RdictChunkedItems::new(self.iter(read_opt, py)?, chunk_size, backwards, from_key) + } + + /// Creates an iterator through the keys. /// /// Examples: /// :: @@ -673,7 +707,38 @@ impl Rdict { RdictKeys::new(self.iter(read_opt, py)?, backwards, from_key) } - /// Iterate through all values. + /// Creates a chunked iterator through the keys. + /// + /// This is more efficient than a normal per-element iterator + /// and will drop the GIL while fetching a chunk. + /// + /// Examples: + /// :: + /// + /// for chunk in db.chunked_keys(chunk_size=1000): + /// print(", ".join(chunk)) + /// + /// Args: + /// chunk_size: the number of items to return. If None, + /// returns all items in one chunk. + /// backwards: iteration direction, forward if `False`. + /// from_key: iterate from key, first seek to this key + /// or the nearest next key for iteration + /// (depending on iteration direction). + /// read_opt: ReadOptions + #[pyo3(signature = (chunk_size = None, backwards = false, from_key = None, read_opt = None))] + fn chunked_keys( + &self, + chunk_size: Option, + backwards: bool, + from_key: Option<&PyAny>, + read_opt: Option<&ReadOptionsPy>, + py: Python, + ) -> PyResult { + RdictChunkedKeys::new(self.iter(read_opt, py)?, chunk_size, backwards, from_key) + } + + /// Creates an iterator through the values. /// /// Examples: /// :: @@ -697,6 +762,37 @@ impl Rdict { RdictValues::new(self.iter(read_opt, py)?, backwards, from_key) } + /// Creates a chunked iterator through the values. + /// + /// This is more efficient than a normal per-element iterator + /// and will drop the GIL while fetching a chunk. 
+ /// + /// Examples: + /// :: + /// + /// for chunk in db.chunked_values(chunk_size=1000): + /// print(", ".join(chunk)) + /// + /// Args: + /// chunk_size: the number of items to return. If None, + /// returns all items in one chunk. + /// backwards: iteration direction, forward if `False`. + /// from_key: iterate from key, first seek to this key + /// or the nearest next key for iteration + /// (depending on iteration direction). + /// read_opt: ReadOptions + #[pyo3(signature = (chunk_size = None, backwards = false, from_key = None, read_opt = None))] + fn chunked_values( + &self, + chunk_size: Option, + backwards: bool, + from_key: Option<&PyAny>, + read_opt: Option<&ReadOptionsPy>, + py: Python, + ) -> PyResult { + RdictChunkedValues::new(self.iter(read_opt, py)?, chunk_size, backwards, from_key) + } + /// Manually flush the current column family. /// /// Notes: diff --git a/src/util.rs b/src/util.rs index 73aaee0..f3495bd 100644 --- a/src/util.rs +++ b/src/util.rs @@ -25,3 +25,22 @@ pub(crate) fn to_cpath>(path: P) -> PyResult { ))), } } + +/// Wrapper around a raw pointer that is safe to send across threads. The user is responsible for +/// ensuring that the pointer is valid for the lifetime of the thread. 
+pub(crate) struct SendSyncMutPtr { + ptr: *mut T, +} + +unsafe impl Send for SendSyncMutPtr {} +unsafe impl Sync for SendSyncMutPtr {} + +impl SendSyncMutPtr { + pub(crate) unsafe fn new(ptr: *mut T) -> Self { + Self { ptr } + } + + pub(crate) unsafe fn get(&self) -> *mut T { + self.ptr + } +} diff --git a/test/bench_rdict.py b/test/bench_rdict.py index 8daf36e..d057a75 100644 --- a/test/bench_rdict.py +++ b/test/bench_rdict.py @@ -84,7 +84,13 @@ def perf_iterator_single_thread(rand_bytes: List[bytes]): count += 1 end = time.perf_counter() assert count == len(rand_bytes) - print("Iterator performance: {} items in {} seconds".format(count, end - start)) + + num_items = count + secs = end - start + item_per_sec = num_items / secs + print( + f"Iterator performance: {num_items} items in {secs} seconds ({item_per_sec} it/s)" + ) rdict.close() @@ -107,10 +113,61 @@ def perf_iter(): for t in threads: t.join() end = time.perf_counter() + + num_items = num_threads * len(rand_bytes) + secs = end - start + item_per_sec = num_items / secs print( - "Iterator performance multi-thread: {} items in {} seconds".format( - num_threads * len(rand_bytes), end - start - ) + f"Iterator performance multi-thread: {num_items} items in {secs} seconds ({item_per_sec} it/s)" + ) + rdict.close() + + +def perf_iterator_chunk_single_thread(rand_bytes: List[bytes], chunk_size: int): + rdict = Rdict("test.db", Options(raw_mode=True)) + items = [] + + start = time.perf_counter() + for batch in rdict.chunked_items(chunk_size=chunk_size): + items.extend(batch) + end = time.perf_counter() + + num_items = len(items) + secs = end - start + item_per_sec = num_items / secs + print( + f"Batched iterator performance: {num_items} items in {secs} seconds ({item_per_sec} it/s)" + ) + rdict.close() + + +def perf_iterator_chunk_multi_thread( + rand_bytes: List[bytes], num_threads: int, batch_size: int +): + rdict = Rdict("test.db", Options(raw_mode=True)) + start = time.perf_counter() + + def perf_iter(): + 
items = [] + for batch in rdict.chunked_items(batch_size): + items.extend(batch) + + assert len(items) == len(rand_bytes) + + threads = [] + for _ in range(num_threads): + t = Thread(target=perf_iter) + t.start() + threads.append(t) + for t in threads: + t.join() + end = time.perf_counter() + + num_items = num_threads * len(rand_bytes) + secs = end - start + item_per_sec = num_items / secs + print( + f"Batched iterator performance multi-thread: {num_items} items in {secs} seconds ({item_per_sec} it/s)" ) rdict.close() @@ -161,7 +218,9 @@ def perf_get(keys: List[bytes]): rand_bytes = gen_rand_bytes() NUM_THREADS = 4 + ITER_CHUNK_SIZE = 25_000 + print() print("Benchmarking Rdict Put...") # perf write perf_put_single_thread(rand_bytes) @@ -172,9 +231,20 @@ def perf_get(keys: List[bytes]): for b in rand_bytes: rdict[b] = b rdict.close() + + print() print("Benchmarking Rdict Iterator...") perf_iterator_single_thread(rand_bytes) perf_iterator_multi_thread(rand_bytes, num_threads=NUM_THREADS) + + print() + print("Benchmarking Rdict Batch Iterator...") + perf_iterator_chunk_single_thread(rand_bytes, chunk_size=ITER_CHUNK_SIZE) + perf_iterator_chunk_multi_thread( + rand_bytes, num_threads=NUM_THREADS, batch_size=ITER_CHUNK_SIZE + ) + + print() print("Benchmarking Rdict Get...") perf_random_get_single_thread(rand_bytes) perf_random_get_multi_thread(rand_bytes, num_threads=NUM_THREADS) diff --git a/test/test_rdict.py b/test/test_rdict.py index a7527c1..9740c5e 100644 --- a/test/test_rdict.py +++ b/test/test_rdict.py @@ -213,6 +213,146 @@ def test_seek_backward_key(self): [k for k in self.test_dict.keys(from_key=key, backwards=True)], ref_list ) + def test_chunk_keys_forward(self): + it = self.test_dict.iter() + it.seek_to_first() + test_list = it.get_chunk_keys() + + ref_list = [k for k in self.ref_dict.keys()] + ref_list.sort() + + self.assertEqual(test_list, ref_list) + + def test_chunk_keys_backward(self): + it = self.test_dict.iter() + it.seek_to_last() + test_list = 
it.get_chunk_keys(backwards=True) + + ref_list = [k for k in self.ref_dict.keys()] + ref_list.sort(reverse=True) + + self.assertEqual(test_list, ref_list) + + def test_chunk_keys_forward_with_count(self): + CHUNK_SIZE = 33 + test_list = [] + + for chunk in self.test_dict.chunked_keys(chunk_size=CHUNK_SIZE): + test_list.extend(chunk) + + ref_list = [k for k in self.ref_dict.keys()] + ref_list.sort() + + self.assertEqual(test_list, ref_list) + + def test_chunk_keys_backward_with_count(self): + CHUNK_SIZE = 33 + test_list = [] + + for chunk in self.test_dict.chunked_keys(chunk_size=CHUNK_SIZE, backwards=True): + test_list.extend(chunk) + + ref_list = [k for k in self.ref_dict.keys()] + ref_list.sort(reverse=True) + + self.assertEqual(test_list, ref_list) + + def test_chunk_items_forward(self): + it = self.test_dict.iter() + it.seek_to_first() + test_list = it.get_chunk_items() + + ref_list = [k for k in self.ref_dict.items()] + ref_list.sort() + + self.assertEqual(test_list, ref_list) + + def test_chunk_items_backward(self): + it = self.test_dict.iter() + it.seek_to_last() + test_list = it.get_chunk_items(backwards=True) + + ref_list = [k for k in self.ref_dict.items()] + ref_list.sort(reverse=True) + + self.assertEqual(test_list, ref_list) + + def test_chunk_items_forward_with_count(self): + CHUNK_SIZE = 33 + test_list = [] + + for chunk in self.test_dict.chunked_items(chunk_size=CHUNK_SIZE): + test_list.extend(chunk) + + ref_list = [k for k in self.ref_dict.items()] + ref_list.sort() + + self.assertEqual(test_list, ref_list) + + def test_chunk_items_backward_with_count(self): + CHUNK_SIZE = 33 + test_list = [] + + for chunk in self.test_dict.chunked_items( + chunk_size=CHUNK_SIZE, backwards=True + ): + test_list.extend(chunk) + + ref_list = [k for k in self.ref_dict.items()] + ref_list.sort(reverse=True) + + self.assertEqual(test_list, ref_list) + + def test_chunk_values_forward(self): + it = self.test_dict.iter() + it.seek_to_first() + test_list = 
it.get_chunk_values() + + ref_list = list(self.ref_dict.items()) + ref_list.sort() + ref_list = [v for _, v in ref_list] + + self.assertEqual(test_list, ref_list) + + def test_chunk_values_backward(self): + it = self.test_dict.iter() + it.seek_to_last() + test_list = it.get_chunk_values(backwards=True) + + ref_list = list(self.ref_dict.items()) + ref_list.sort(reverse=True) + ref_list = [v for _, v in ref_list] + + self.assertEqual(test_list, ref_list) + + def test_chunk_values_forward_with_count(self): + CHUNK_SIZE = 33 + test_list = [] + + for chunk in self.test_dict.chunked_values(chunk_size=CHUNK_SIZE): + test_list.extend(chunk) + + ref_list = list(self.ref_dict.items()) + ref_list.sort() + ref_list = [v for _, v in ref_list] + + self.assertEqual(test_list, ref_list) + + def test_chunk_values_backward_with_count(self): + CHUNK_SIZE = 33 + test_list = [] + + for chunk in self.test_dict.chunked_values( + chunk_size=CHUNK_SIZE, backwards=True + ): + test_list.extend(chunk) + + ref_list = list(self.ref_dict.items()) + ref_list.sort(reverse=True) + ref_list = [v for _, v in ref_list] + + self.assertEqual(test_list, ref_list) + @classmethod def tearDownClass(cls): cls.test_dict.close() From 995951a17aa86446e92d09bc304a3390dd1c4ae7 Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Sat, 16 Dec 2023 12:13:44 -0500 Subject: [PATCH 2/4] Wrap iterator in lock to prevent races --- src/iter.rs | 208 ++++++++++++++++++++++++++++++++++------------------ src/util.rs | 7 +- 2 files changed, 140 insertions(+), 75 deletions(-) diff --git a/src/iter.rs b/src/iter.rs index d957596..bc1e47a 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -1,15 +1,16 @@ use crate::db_reference::DbReferenceHolder; use crate::encoder::{decode_value, encode_key}; use crate::exceptions::DbClosedError; -use crate::util::{error_message, SendSyncMutPtr}; +use crate::util::{error_message, SendMutPtr}; use crate::{ReadOpt, ReadOptionsPy}; use core::slice; use libc::{c_char, c_uchar, size_t}; -use 
pyo3::exceptions::PyException; +use pyo3::exceptions::{PyException, PyRuntimeError}; use pyo3::prelude::*; use rocksdb::{AsColumnFamilyRef, UnboundColumnFamily}; +use std::ops::Deref; use std::ptr::null_mut; -use std::sync::Arc; +use std::sync::{Arc, Mutex, MutexGuard}; #[pyclass] #[allow(dead_code)] @@ -17,10 +18,9 @@ pub(crate) struct RdictIter { /// iterator must keep a reference count of DB to keep DB alive. pub(crate) db: DbReferenceHolder, - // This is a wrapper around a `*mut rocksdb_iterator_t`. It is wrapped in a `SendSyncMutPtr`, so - // it is the responsibility of any user that sends it across threads to ensure that the thread - // does not outlive this iterator. - pub(crate) inner: SendSyncMutPtr, + // This is wrapped in a lock, since this iterator can theoretically be shared between Python + // threads. + pub(crate) inner: Mutex>, /// When iterate_upper_bound is set, the inner C iterator keeps a pointer to the upper bound /// inside `_readopts`. Storing this makes sure the upper bound is always alive when the @@ -69,10 +69,10 @@ impl RdictIter { let inner = unsafe { match cf { - None => SendSyncMutPtr::new(librocksdb_sys::rocksdb_create_iterator( + None => SendMutPtr::new(librocksdb_sys::rocksdb_create_iterator( db_inner, readopts.0, )), - Some(cf) => SendSyncMutPtr::new(librocksdb_sys::rocksdb_create_iterator_cf( + Some(cf) => SendMutPtr::new(librocksdb_sys::rocksdb_create_iterator_cf( db_inner, readopts.0, cf.inner(), @@ -82,12 +82,45 @@ impl RdictIter { Ok(RdictIter { db: db.clone(), - inner, + inner: Mutex::new(inner), readopts, pickle_loads: pickle_loads.clone(), raw_mode, }) } + + fn is_valid_locked( + &self, + inner_locked: &MutexGuard<'_, SendMutPtr>, + ) -> bool { + unsafe { librocksdb_sys::rocksdb_iter_valid(inner_locked.deref().get()) != 0 } + } + + fn prev_locked( + &self, + inner_locked: &MutexGuard<'_, SendMutPtr>, + ) { + unsafe { + librocksdb_sys::rocksdb_iter_prev(inner_locked.deref().get()); + } + } + + fn next_locked( + &self, + 
inner_locked: &MutexGuard<'_, SendMutPtr>, + ) { + unsafe { + librocksdb_sys::rocksdb_iter_next(inner_locked.deref().get()); + } + } + + fn get_inner_locked( + &self, + ) -> PyResult>> { + self.inner + .lock() + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } } #[pymethods] @@ -99,8 +132,9 @@ impl RdictIter { /// returned `false`, use the [`status`](DBRawIteratorWithThreadMode::status) method. `status` will never /// return an error when `valid` is `true`. #[inline] - pub fn valid(&self) -> bool { - unsafe { librocksdb_sys::rocksdb_iter_valid(self.inner.get()) != 0 } + pub fn valid(&self) -> PyResult { + let inner_locked = self.get_inner_locked()?; + Ok(self.is_valid_locked(&inner_locked)) } /// Returns an error `Result` if the iterator has encountered an error @@ -110,8 +144,9 @@ impl RdictIter { /// Performing a seek will discard the current status. pub fn status(&self) -> PyResult<()> { let mut err: *mut c_char = null_mut(); + let inner_locked = self.get_inner_locked()?; unsafe { - librocksdb_sys::rocksdb_iter_get_error(self.inner.get(), &mut err); + librocksdb_sys::rocksdb_iter_get_error(inner_locked.deref().get(), &mut err); } if !err.is_null() { Err(PyException::new_err(error_message(err))) @@ -144,10 +179,13 @@ impl RdictIter { /// /// del iter, db /// Rdict.destroy(path, Options()) - pub fn seek_to_first(&mut self) { + pub fn seek_to_first(&mut self) -> PyResult<()> { + let inner_locked = self.get_inner_locked()?; unsafe { - librocksdb_sys::rocksdb_iter_seek_to_first(self.inner.get()); + librocksdb_sys::rocksdb_iter_seek_to_first(inner_locked.deref().get()); } + + Ok(()) } /// Seeks to the last key in the database. 
@@ -174,10 +212,13 @@ impl RdictIter { /// /// del iter, db /// Rdict.destroy(path, Options()) - pub fn seek_to_last(&mut self) { + pub fn seek_to_last(&mut self) -> PyResult<()> { + let inner_locked = self.get_inner_locked()?; unsafe { - librocksdb_sys::rocksdb_iter_seek_to_last(self.inner.get()); + librocksdb_sys::rocksdb_iter_seek_to_last(inner_locked.deref().get()); } + + Ok(()) } /// Seeks to the specified key or the first key that lexicographically follows it. @@ -202,9 +243,11 @@ impl RdictIter { /// Rdict.destroy(path, Options()) pub fn seek(&mut self, key: &PyAny) -> PyResult<()> { let key = encode_key(key, self.raw_mode)?; + + let inner_locked = self.get_inner_locked()?; unsafe { librocksdb_sys::rocksdb_iter_seek( - self.inner.get(), + inner_locked.deref().get(), key.as_ptr() as *const c_char, key.len() as size_t, ); @@ -235,9 +278,10 @@ impl RdictIter { /// Rdict.destroy(path, Options()) pub fn seek_for_prev(&mut self, key: &PyAny) -> PyResult<()> { let key = encode_key(key, self.raw_mode)?; + let inner_locked = self.get_inner_locked()?; unsafe { librocksdb_sys::rocksdb_iter_seek_for_prev( - self.inner.get(), + inner_locked.deref().get(), key.as_ptr() as *const c_char, key.len() as size_t, ); @@ -246,29 +290,33 @@ impl RdictIter { } /// Seeks to the next key. - pub fn next(&mut self) { - unsafe { - librocksdb_sys::rocksdb_iter_next(self.inner.get()); - } + pub fn next(&mut self) -> PyResult<()> { + let inner_locked = self.get_inner_locked()?; + self.next_locked(&inner_locked); + Ok(()) } /// Seeks to the previous key. - pub fn prev(&mut self) { - unsafe { - librocksdb_sys::rocksdb_iter_prev(self.inner.get()); - } + pub fn prev(&mut self) -> PyResult<()> { + let inner_locked = self.get_inner_locked()?; + self.prev_locked(&inner_locked); + Ok(()) } /// Returns the current key. 
pub fn key(&self, py: Python) -> PyResult { - if self.valid() { + let inner_locked = self.get_inner_locked()?; + if self.is_valid_locked(&inner_locked) { + let inner_locked = self.get_inner_locked()?; + // Safety Note: This is safe as all methods that may invalidate the buffer returned // take `&mut self`, so borrow checker will prevent use of buffer after seek. unsafe { let mut key_len: size_t = 0; let key_len_ptr: *mut size_t = &mut key_len; - let key_ptr = librocksdb_sys::rocksdb_iter_key(self.inner.get(), key_len_ptr) - as *const c_uchar; + let key_ptr = + librocksdb_sys::rocksdb_iter_key(inner_locked.deref().get(), key_len_ptr) + as *const c_uchar; let key = slice::from_raw_parts(key_ptr, key_len); Ok(decode_value(py, key, &self.pickle_loads, self.raw_mode)?) } @@ -279,14 +327,16 @@ impl RdictIter { /// Returns the current value. pub fn value(&self, py: Python) -> PyResult { - if self.valid() { + let inner_locked = self.get_inner_locked()?; + if self.is_valid_locked(&inner_locked) { // Safety Note: This is safe as all methods that may invalidate the buffer returned // take `&mut self`, so borrow checker will prevent use of buffer after seek. unsafe { let mut val_len: size_t = 0; let val_len_ptr: *mut size_t = &mut val_len; - let val_ptr = librocksdb_sys::rocksdb_iter_value(self.inner.get(), val_len_ptr) - as *const c_uchar; + let val_ptr = + librocksdb_sys::rocksdb_iter_value(inner_locked.deref().get(), val_len_ptr) + as *const c_uchar; let value = slice::from_raw_parts(val_ptr, val_len); Ok(decode_value(py, value, &self.pickle_loads, self.raw_mode)?) 
} @@ -311,17 +361,22 @@ impl RdictIter { backwards: bool, py: Python, ) -> PyResult> { - let raw_keys = py.allow_threads(|| { + let raw_keys = py.allow_threads(|| -> PyResult>> { let mut raw_keys = Vec::new(); - while self.valid() && raw_keys.len() < chunk_size.unwrap_or(usize::MAX) { + let inner_locked = self.get_inner_locked()?; + + while self.is_valid_locked(&inner_locked) + && raw_keys.len() < chunk_size.unwrap_or(usize::MAX) + { // Safety: This is safe for multiple reasons: // * It makes a copy of the buffer before returning. // * This `allow_threads` block does not outlive the iterator's lifetime. let key = unsafe { let mut key_len: size_t = 0; let key_len_ptr: *mut size_t = &mut key_len; - let key_ptr = librocksdb_sys::rocksdb_iter_key(self.inner.get(), key_len_ptr) - as *const c_uchar; + let key_ptr = + librocksdb_sys::rocksdb_iter_key(inner_locked.deref().get(), key_len_ptr) + as *const c_uchar; slice::from_raw_parts(key_ptr, key_len) .to_vec() .into_boxed_slice() @@ -329,14 +384,14 @@ impl RdictIter { raw_keys.push(key); if backwards { - self.prev(); + self.prev_locked(&inner_locked); } else { - self.next(); + self.next_locked(&inner_locked); } } - raw_keys - }); + Ok(raw_keys) + })?; raw_keys .into_iter() @@ -360,18 +415,22 @@ impl RdictIter { backwards: bool, py: Python, ) -> PyResult> { - let raw_values = py.allow_threads(|| { + let raw_values = py.allow_threads(|| -> PyResult>> { let mut raw_values = Vec::new(); - while self.valid() && raw_values.len() < chunk_size.unwrap_or(usize::MAX) { + let inner_locked = self.get_inner_locked()?; + while self.is_valid_locked(&inner_locked) + && raw_values.len() < chunk_size.unwrap_or(usize::MAX) + { // Safety: This is safe for multiple reasons: // * It makes a copy of the buffer before returning. // * This `allow_threads` block does not outlive the iterator's lifetime. 
let value = unsafe { let mut value_len: size_t = 0; let value_len_ptr: *mut size_t = &mut value_len; - let value_ptr = - librocksdb_sys::rocksdb_iter_value(self.inner.get(), value_len_ptr) - as *const c_uchar; + let value_ptr = librocksdb_sys::rocksdb_iter_value( + inner_locked.deref().get(), + value_len_ptr, + ) as *const c_uchar; slice::from_raw_parts(value_ptr, value_len) .to_vec() .into_boxed_slice() @@ -379,14 +438,14 @@ impl RdictIter { raw_values.push(value); if backwards { - self.prev(); + self.prev_locked(&inner_locked) } else { - self.next(); + self.next_locked(&inner_locked); } } - raw_values - }); + Ok(raw_values) + })?; raw_values .into_iter() @@ -410,17 +469,21 @@ impl RdictIter { backwards: bool, py: Python, ) -> PyResult> { - let raw_items = py.allow_threads(|| { + let raw_items = py.allow_threads(|| -> PyResult, Box<[u8]>)>> { let mut raw_items = Vec::new(); - while self.valid() && raw_items.len() < chunk_size.unwrap_or(usize::MAX) { + let inner_locked = self.get_inner_locked()?; + while self.is_valid_locked(&inner_locked) + && raw_items.len() < chunk_size.unwrap_or(usize::MAX) + { // Safety: This is safe for multiple reasons: // * It makes a copy of the buffer before returning. // * This `allow_threads` block does not outlive the iterator's lifetime. 
let key = unsafe { let mut key_len: size_t = 0; let key_len_ptr: *mut size_t = &mut key_len; - let key_ptr = librocksdb_sys::rocksdb_iter_key(self.inner.get(), key_len_ptr) - as *const c_uchar; + let key_ptr = + librocksdb_sys::rocksdb_iter_key(inner_locked.deref().get(), key_len_ptr) + as *const c_uchar; slice::from_raw_parts(key_ptr, key_len) .to_vec() .into_boxed_slice() @@ -432,9 +495,10 @@ impl RdictIter { let value = unsafe { let mut value_len: size_t = 0; let value_len_ptr: *mut size_t = &mut value_len; - let value_ptr = - librocksdb_sys::rocksdb_iter_value(self.inner.get(), value_len_ptr) - as *const c_uchar; + let value_ptr = librocksdb_sys::rocksdb_iter_value( + inner_locked.deref().get(), + value_len_ptr, + ) as *const c_uchar; slice::from_raw_parts(value_ptr, value_len) .to_vec() .into_boxed_slice() @@ -443,14 +507,14 @@ impl RdictIter { raw_items.push((key, value)); if backwards { - self.prev(); + self.prev_locked(&inner_locked); } else { - self.next(); + self.next_locked(&inner_locked); } } - raw_items - }); + Ok(raw_items) + })?; raw_items .into_iter() @@ -465,8 +529,10 @@ impl RdictIter { impl Drop for RdictIter { fn drop(&mut self) { - unsafe { - librocksdb_sys::rocksdb_iter_destroy(self.inner.get()); + if let Ok(inner_locked) = self.get_inner_locked() { + unsafe { + librocksdb_sys::rocksdb_iter_destroy(inner_locked.deref().get()); + } } } } @@ -482,12 +548,12 @@ macro_rules! impl_iter { } fn __next__(mut slf: PyRefMut, py: Python) -> PyResult> { - if slf.inner.valid() { + if slf.inner.valid()? { $(let $field = slf.inner.$field(py)?;)* if slf.backwards { - slf.inner.prev(); + slf.inner.prev()?; } else { - slf.inner.next(); + slf.inner.next()?; } Ok(Some(($($field),*).to_object(py))) } else { @@ -507,9 +573,9 @@ macro_rules! impl_iter { } } else { if backwards { - inner.seek_to_last(); + inner.seek_to_last()?; } else { - inner.seek_to_first(); + inner.seek_to_first()?; } } Ok(Self { @@ -537,7 +603,7 @@ macro_rules! 
impl_chunked_iter { } fn __next__(&mut self, py: Python) -> PyResult> { - if self.inner.valid() { + if self.inner.valid()? { Ok(Some( self.inner .$iter_chunk_fn(self.chunk_size, self.backwards, py) @@ -565,9 +631,9 @@ macro_rules! impl_chunked_iter { } } else { if backwards { - inner.seek_to_last(); + inner.seek_to_last()?; } else { - inner.seek_to_first(); + inner.seek_to_first()?; } } Ok(Self { diff --git a/src/util.rs b/src/util.rs index f3495bd..af3419d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -28,14 +28,13 @@ pub(crate) fn to_cpath>(path: P) -> PyResult { /// Wrapper around a raw pointer that is safe to send across threads. The user is responsible for /// ensuring that the pointer is valid for the lifetime of the thread. -pub(crate) struct SendSyncMutPtr { +pub(crate) struct SendMutPtr { ptr: *mut T, } -unsafe impl Send for SendSyncMutPtr {} -unsafe impl Sync for SendSyncMutPtr {} +unsafe impl Send for SendMutPtr {} -impl SendSyncMutPtr { +impl SendMutPtr { pub(crate) unsafe fn new(ptr: *mut T) -> Self { Self { ptr } } From 8ef38e9e782dca7981139e7d0c45c40476afbde5 Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Sat, 16 Dec 2023 12:29:40 -0500 Subject: [PATCH 3/4] Set default chunk size to 10000 --- python/rocksdict/rocksdict.pyi | 12 ++++++------ src/rdict.rs | 6 +++--- test/bench_rdict.py | 2 +- test/test_rdict.py | 12 ++++++------ 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/rocksdict/rocksdict.pyi b/python/rocksdict/rocksdict.pyi index 9c3016f..a3ea719 100644 --- a/python/rocksdict/rocksdict.pyi +++ b/python/rocksdict/rocksdict.pyi @@ -450,7 +450,7 @@ class Rdict: ) -> RdictItems: ... def chunked_items( self, - chunk_size: Optional[int] = None, + chunk_size: Optional[int] = 10_000, backwards: bool = False, from_key: Union[str, int, float, bytes, bool, None] = None, read_opt: Union[ReadOptions, None] = None, @@ -463,7 +463,7 @@ class Rdict: ) -> RdictKeys: ... 
def chunked_keys( self, - chunk_size: Optional[int] = None, + chunk_size: Optional[int] = 10_000, backwards: bool = False, from_key: Union[str, int, float, bytes, bool, None] = None, read_opt: Union[ReadOptions, None] = None, @@ -476,7 +476,7 @@ class Rdict: ) -> RdictValues: ... def chunked_values( self, - chunk_size: Optional[int] = None, + chunk_size: Optional[int] = 10_000, backwards: bool = False, from_key: Union[str, int, float, bytes, bool, None] = None, read_opt: Union[ReadOptions, None] = None, @@ -565,13 +565,13 @@ class RdictIter: def key(self) -> Any: ... def value(self) -> Any: ... def get_chunk_keys( - self, chunk_size: Optional[int] = None, backwards: bool = False + self, chunk_size: Optional[int] = 10_000, backwards: bool = False ) -> List[Union[str, int, float, bytes, bool]]: ... def get_chunk_values( - self, chunk_size: Optional[int] = None, backwards: bool = False + self, chunk_size: Optional[int] = 10_000, backwards: bool = False ) -> List[Any]: ... def get_chunk_items( - self, chunk_size: Optional[int] = None, backwards: bool = False + self, chunk_size: Optional[int] = 10_000, backwards: bool = False ) -> List[Tuple[Union[str, int, float, bytes, bool], Any]]: ... class IngestExternalFileOptions: diff --git a/src/rdict.rs b/src/rdict.rs index 9342f74..3387657 100644 --- a/src/rdict.rs +++ b/src/rdict.rs @@ -671,7 +671,7 @@ impl Rdict { /// or the nearest next key for iteration /// (depending on iteration direction). /// read_opt: ReadOptions - #[pyo3(signature = (chunk_size = None, backwards = false, from_key = None, read_opt = None))] + #[pyo3(signature = (chunk_size = 10000, backwards = false, from_key = None, read_opt = None))] fn chunked_items( &self, chunk_size: Option, @@ -726,7 +726,7 @@ impl Rdict { /// or the nearest next key for iteration /// (depending on iteration direction). 
/// read_opt: ReadOptions - #[pyo3(signature = (chunk_size = None, backwards = false, from_key = None, read_opt = None))] + #[pyo3(signature = (chunk_size = 10000, backwards = false, from_key = None, read_opt = None))] fn chunked_keys( &self, chunk_size: Option, @@ -781,7 +781,7 @@ impl Rdict { /// or the nearest next key for iteration /// (depending on iteration direction). /// read_opt: ReadOptions - #[pyo3(signature = (chunk_size = None, backwards = false, from_key = None, read_opt = None))] + #[pyo3(signature = (chunk_size = 10000, backwards = false, from_key = None, read_opt = None))] fn chunked_values( &self, chunk_size: Option, diff --git a/test/bench_rdict.py b/test/bench_rdict.py index d057a75..fa42b04 100644 --- a/test/bench_rdict.py +++ b/test/bench_rdict.py @@ -218,7 +218,7 @@ def perf_get(keys: List[bytes]): rand_bytes = gen_rand_bytes() NUM_THREADS = 4 - ITER_CHUNK_SIZE = 25_000 + ITER_CHUNK_SIZE = 10_000 print() print("Benchmarking Rdict Put...") diff --git a/test/test_rdict.py b/test/test_rdict.py index 9740c5e..c3c31db 100644 --- a/test/test_rdict.py +++ b/test/test_rdict.py @@ -216,7 +216,7 @@ def test_seek_backward_key(self): def test_chunk_keys_forward(self): it = self.test_dict.iter() it.seek_to_first() - test_list = it.get_chunk_keys() + test_list = it.get_chunk_keys(chunk_size=None) ref_list = [k for k in self.ref_dict.keys()] ref_list.sort() @@ -226,7 +226,7 @@ def test_chunk_keys_forward(self): def test_chunk_keys_backward(self): it = self.test_dict.iter() it.seek_to_last() - test_list = it.get_chunk_keys(backwards=True) + test_list = it.get_chunk_keys(chunk_size=None, backwards=True) ref_list = [k for k in self.ref_dict.keys()] ref_list.sort(reverse=True) @@ -260,7 +260,7 @@ def test_chunk_keys_backward_with_count(self): def test_chunk_items_forward(self): it = self.test_dict.iter() it.seek_to_first() - test_list = it.get_chunk_items() + test_list = it.get_chunk_items(chunk_size=None) ref_list = [k for k in self.ref_dict.items()] 
ref_list.sort() @@ -270,7 +270,7 @@ def test_chunk_items_forward(self): def test_chunk_items_backward(self): it = self.test_dict.iter() it.seek_to_last() - test_list = it.get_chunk_items(backwards=True) + test_list = it.get_chunk_items(chunk_size=None, backwards=True) ref_list = [k for k in self.ref_dict.items()] ref_list.sort(reverse=True) @@ -306,7 +306,7 @@ def test_chunk_items_backward_with_count(self): def test_chunk_values_forward(self): it = self.test_dict.iter() it.seek_to_first() - test_list = it.get_chunk_values() + test_list = it.get_chunk_values(chunk_size=None) ref_list = list(self.ref_dict.items()) ref_list.sort() @@ -317,7 +317,7 @@ def test_chunk_values_forward(self): def test_chunk_values_backward(self): it = self.test_dict.iter() it.seek_to_last() - test_list = it.get_chunk_values(backwards=True) + test_list = it.get_chunk_values(chunk_size=None, backwards=True) ref_list = list(self.ref_dict.items()) ref_list.sort(reverse=True) From 81b308c34762f52b66259c183b3f25569da4163f Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Sat, 16 Dec 2023 12:53:17 -0500 Subject: [PATCH 4/4] Add iterations to single-thread batch iterator bench --- test/bench_rdict.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/bench_rdict.py b/test/bench_rdict.py index fa42b04..60263f7 100644 --- a/test/bench_rdict.py +++ b/test/bench_rdict.py @@ -123,16 +123,19 @@ def perf_iter(): rdict.close() -def perf_iterator_chunk_single_thread(rand_bytes: List[bytes], chunk_size: int): +def perf_iterator_chunk_single_thread( + rand_bytes: List[bytes], chunk_size: int, num_iterations: int +): rdict = Rdict("test.db", Options(raw_mode=True)) - items = [] start = time.perf_counter() - for batch in rdict.chunked_items(chunk_size=chunk_size): - items.extend(batch) + for _ in range(num_iterations): + items = [] + for batch in rdict.chunked_items(chunk_size=chunk_size): + items.extend(batch) end = time.perf_counter() - num_items = len(items) + num_items = 
len(items) * num_iterations secs = end - start item_per_sec = num_items / secs print( @@ -239,7 +242,9 @@ def perf_get(keys: List[bytes]): print() print("Benchmarking Rdict Batch Iterator...") - perf_iterator_chunk_single_thread(rand_bytes, chunk_size=ITER_CHUNK_SIZE) + perf_iterator_chunk_single_thread( + rand_bytes, chunk_size=ITER_CHUNK_SIZE, num_iterations=NUM_THREADS + ) perf_iterator_chunk_multi_thread( rand_bytes, num_threads=NUM_THREADS, batch_size=ITER_CHUNK_SIZE )