From f70a519524d2ce4efbdf739ed874c2bad574c920 Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Wed, 22 Nov 2023 13:10:54 -0500 Subject: [PATCH 1/2] Do not try to flush on secondary instances --- src/rdict.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/rdict.rs b/src/rdict.rs index 64161eb..1891f80 100644 --- a/src/rdict.rs +++ b/src/rdict.rs @@ -1020,11 +1020,14 @@ impl Rdict { fn close(&mut self) -> PyResult<()> { let f_opt = &self.flush_opt; let db = self.get_db()?.borrow(); - if let AccessTypeInner::ReadOnly { .. } = self.access_type.0 { - drop(db); - drop(self.column_family.take()); - drop(self.db.take()); - return Ok(()); + match &self.access_type.0 { + AccessTypeInner::ReadOnly { .. } | AccessTypeInner::Secondary { .. } => { + drop(db); + drop(self.column_family.take()); + drop(self.db.take()); + return Ok(()); + } + _ => (), }; let flush_wal_result = db.flush_wal(true); let flush_result = if let Some(cf) = &self.column_family { From 9437eeb10c8850fb5ad13024787451ea4059dc8e Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Wed, 22 Nov 2023 19:19:30 -0500 Subject: [PATCH 2/2] Add tests for secondary index --- test/test_rdict.py | 168 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 139 insertions(+), 29 deletions(-) diff --git a/test/test_rdict.py b/test/test_rdict.py index abc1e4d..a7527c1 100644 --- a/test/test_rdict.py +++ b/test/test_rdict.py @@ -1,6 +1,14 @@ import unittest from sys import getrefcount -from rocksdict import Rdict, Options, PlainTableFactoryOptions, SliceTransform, CuckooTableOptions, DbClosedError +from rocksdict import ( + AccessType, + Rdict, + Options, + PlainTableFactoryOptions, + SliceTransform, + CuckooTableOptions, + DbClosedError, +) from random import randint, random, getrandbits import os import sys @@ -12,12 +20,10 @@ def randbytes(n): """Generate n random bytes.""" - return getrandbits(n * 8).to_bytes(n, 'little') + return getrandbits(n * 8).to_bytes(n, "little") -def compare_dicts(test_case: unittest.TestCase, - ref_dict: dict, - test_dict: Rdict): +def compare_dicts(test_case: unittest.TestCase, ref_dict: dict, test_dict: Rdict): # assert that the values are the same test_case.assertEqual({k: v for k, v in test_dict.items()}, ref_dict) @@ -64,8 +70,8 @@ class TestGetDelCustomDumpsLoads(unittest.TestCase): def setUpClass(cls) -> None: cls.opt = Options() cls.test_dict = Rdict(cls.path, cls.opt) - cls.test_dict.set_loads(lambda x: loads(x.decode('utf-8'))) - cls.test_dict.set_dumps(lambda x: bytes(dumps(x), 'utf-8')) + cls.test_dict.set_loads(lambda x: loads(x.decode("utf-8"))) + cls.test_dict.set_dumps(lambda x: bytes(dumps(x), "utf-8")) cls.test_dict["a"] = "a" cls.test_dict[123] = 123 cls.test_dict["ok"] = ["o", "k"] @@ -110,7 +116,9 @@ def setUpClass(cls) -> None: value = randbytes(20) cls.test_dict[key] = value cls.ref_dict[key] = value - keys_to_remove = list(set(randint(0, len(cls.ref_dict) - 1) for _ in range(50000))) + keys_to_remove = list( + set(randint(0, len(cls.ref_dict) - 1) for _ in range(50000)) + ) keys = [k for k in cls.ref_dict.keys()] keys_to_remove = [keys[i] for i in keys_to_remove] for key in keys_to_remove: @@ -127,7 +135,9 @@ def test_seek_backward_key(self): key = randbytes(20) ref_list = [k for k in self.ref_dict.keys() if k <= key] ref_list.sort(reverse=True) - self.assertEqual([k for k in self.test_dict.keys(from_key=key, backwards=True)], ref_list) + self.assertEqual( + [k for k in self.test_dict.keys(from_key=key, backwards=True)], ref_list + ) def test_may_exists(self): for k, v in self.ref_dict.items(): @@ -138,13 +148,17 @@ def test_may_exists(self): def test_seek_forward(self): key = randbytes(20) - self.assertEqual({k: v for k, v in self.test_dict.items(from_key=key)}, - {k: v for k, v in self.ref_dict.items() if k >= key}) + self.assertEqual( + {k: v for k, v in self.test_dict.items(from_key=key)}, + {k: v for k, v in self.ref_dict.items() if k >= key}, + ) def test_seek_backward(self): key = randbytes(20) - self.assertEqual({k: v for k, v in self.test_dict.items(from_key=key, backwards=True)}, - {k: v for k, v in self.ref_dict.items() if k <= key}) + self.assertEqual( + {k: v for k, v in self.test_dict.items(from_key=key, backwards=True)}, + {k: v for k, v in self.ref_dict.items() if k <= key}, + ) @classmethod def tearDownClass(cls): @@ -174,12 +188,16 @@ def setUpClass(cls) -> None: del cls.test_dict[key] def test_seek_forward(self): - self.assertEqual({k: v for k, v in self.test_dict.items()}, - {k: v for k, v in self.ref_dict.items()}) + self.assertEqual( + {k: v for k, v in self.test_dict.items()}, + {k: v for k, v in self.ref_dict.items()}, + ) def test_seek_backward(self): - self.assertEqual({k: v for k, v in self.test_dict.items(backwards=True)}, - {k: v for k, v in self.ref_dict.items()}) + self.assertEqual( + {k: v for k, v in self.test_dict.items(backwards=True)}, + {k: v for k, v in self.ref_dict.items()}, + ) def test_seek_forward_key(self): key = randint(0, TEST_INT_RANGE_UPPER - 1) @@ -191,7 +209,9 @@ def test_seek_backward_key(self): key = randint(0, TEST_INT_RANGE_UPPER - 1) ref_list = [k for k in self.ref_dict.keys() if k <= key] ref_list.sort(reverse=True) - self.assertEqual([k for k in self.test_dict.keys(from_key=key, backwards=True)], ref_list) + self.assertEqual( + [k for k in self.test_dict.keys(from_key=key, backwards=True)], ref_list + ) @classmethod def tearDownClass(cls): @@ -250,7 +270,10 @@ def test_reopen(self): def test_get_batch(self): keys = list(self.ref_dict.keys())[:100] - self.assertEqual(self.test_dict[keys + ["no such key"] * 3], [self.ref_dict[k] for k in keys] + [None] * 3) + self.assertEqual( + self.test_dict[keys + ["no such key"] * 3], + [self.ref_dict[k] for k in keys] + [None] * 3, + ) @classmethod def tearDownClass(cls): @@ -327,7 +350,10 @@ def test_reopen(self): def test_get_batch(self): keys = list(self.ref_dict.keys())[:100] - self.assertEqual(self.test_dict[keys + ["no such key"] * 3], [self.ref_dict[k] for k in keys] + [None] * 3) + self.assertEqual( + self.test_dict[keys + ["no such key"] * 3], + [self.ref_dict[k] for k in keys] + [None] * 3, + ) @classmethod def tearDownClass(cls): @@ -345,7 +371,7 @@ def setUpClass(cls) -> None: cls.opt = Options() cls.opt.create_if_missing(True) # for the moment do not use CuckooTable on windows - if not sys.platform.startswith('win'): + if not sys.platform.startswith("win"): cls.opt.set_cuckoo_table_factory(CuckooTableOptions()) cls.opt.set_allow_mmap_reads(True) cls.opt.set_allow_mmap_writes(True) @@ -388,7 +414,10 @@ def test_reopen(self): def test_get_batch(self): keys = list(self.ref_dict.keys())[:100] - self.assertEqual(self.test_dict[keys + ["no such key"] * 3], [self.ref_dict[k] for k in keys] + [None] * 3) + self.assertEqual( + self.test_dict[keys + ["no such key"] * 3], + [self.ref_dict[k] for k in keys] + [None] * 3, + ) @classmethod def tearDownClass(cls): @@ -568,7 +597,9 @@ def setUpClass(cls) -> None: def test_column_families_custom_options_auto_reopen(self): ds = self.test_dict.create_column_family(name="string") - di = self.test_dict.create_column_family(name="integer", options=self.plain_opts) + di = self.test_dict.create_column_family( + name="integer", options=self.plain_opts + ) for i in range(1000): di[i] = i * i @@ -613,7 +644,9 @@ def setUpClass(cls) -> None: def test_column_families_custom_options_auto_reopen_override(self): ds = self.test_dict.create_column_family(name="string") - di = self.test_dict.create_column_family(name="integer", options=self.plain_opts) + di = self.test_dict.create_column_family( + name="integer", options=self.plain_opts + ) for i in range(1000): di[i] = i * i @@ -633,9 +666,9 @@ def test_column_families_custom_options_auto_reopen_override(self): ds = self.test_dict.get_column_family("string") di = self.test_dict.get_column_family("integer") db = self.test_dict.get_column_family("bytes") - db[b'great'] = b'hello world' + db[b"great"] = b"hello world" assert self.test_dict["ok"] - assert db[b'great'] == b'hello world' + assert db[b"great"] == b"hello world" compare_dicts(self, {i: i**2 for i in range(1000)}, di) compare_dicts(self, {str(i): str(i**2) for i in range(1000)}, ds) ds.close() @@ -648,9 +681,9 @@ def test_column_families_custom_options_auto_reopen_override(self): ds = self.test_dict.get_column_family("string") di = self.test_dict.get_column_family("integer") db = self.test_dict.get_column_family("bytes") - db[b'great'] = b'hello world' + db[b"great"] = b"hello world" assert self.test_dict["ok"] - assert db[b'great'] == b'hello world' + assert db[b"great"] == b"hello world" compare_dicts(self, {i: i**2 for i in range(1000)}, di) compare_dicts(self, {str(i): str(i**2) for i in range(1000)}, ds) ds.close() @@ -664,5 +697,82 @@ def tearDownClass(cls): Rdict.destroy(cls.path) -if __name__ == '__main__': +class TestIntWithSecondary(unittest.TestCase): + test_dict = None + ref_dict = None + secondary_dict = None + opt = None + path = "./temp_int_with_secondary" + secondary_path = "./temp_int_with_secondary.secondary" + + @classmethod + def setUpClass(cls) -> None: + cls.opt = Options() + cls.opt.create_if_missing(True) + cls.test_dict = Rdict(cls.path, cls.opt) + + cls.secondary_dict = Rdict( + cls.path, + options=cls.opt, + access_type=AccessType.secondary(cls.secondary_path), + ) + + cls.ref_dict = dict() + + def test_add_integer(self): + for i in range(10000): + key = randint(0, TEST_INT_RANGE_UPPER - 1) + value = randint(0, TEST_INT_RANGE_UPPER - 1) + self.ref_dict[key] = value + self.test_dict[key] = value + + self.test_dict.flush(True) + self.secondary_dict.try_catch_up_with_primary() + compare_dicts(self, self.ref_dict, self.secondary_dict) + + def test_delete_integer(self): + for i in range(5000): + key = randint(0, TEST_INT_RANGE_UPPER - 1) + if key in self.ref_dict: + del self.ref_dict[key] + del self.test_dict[key] + + self.test_dict.flush(True) + self.secondary_dict.try_catch_up_with_primary() + compare_dicts(self, self.ref_dict, self.secondary_dict) + + def test_delete_range(self): + to_delete = [] + for key in self.ref_dict: + if key >= 99999: + to_delete.append(key) + for key in to_delete: + del self.ref_dict[key] + self.test_dict.delete_range(99999, 10000000) + + self.test_dict.flush(True) + self.secondary_dict.try_catch_up_with_primary() + compare_dicts(self, self.ref_dict, self.secondary_dict) + + def test_reopen(self): + self.secondary_dict.close() + + self.assertRaises(DbClosedError, lambda: self.secondary_dict.get(1)) + + self.secondary_dict = Rdict( + self.path, + options=self.opt, + access_type=AccessType.secondary(self.secondary_path), + ) + compare_dicts(self, self.ref_dict, self.secondary_dict) + + @classmethod + def tearDownClass(cls): + del cls.test_dict + del cls.secondary_dict + Rdict.destroy(cls.path, cls.opt) + Rdict.destroy(cls.secondary_path, cls.opt) + + +if __name__ == "__main__": unittest.main()