Skip to content

Commit

Permalink
Quotient Filter: merge (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
barrust authored Jan 13, 2024
1 parent 28a58b0 commit c1b8310
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 25 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# PyProbables Changelog

### Version 0.6.1

* Quotient Filter:
* Add ability to get hashes from the filter either as a list, or as a generator
* Add quotient filter expand capability, auto and on request
* Add QuotientFilterError exception
* Add merge functionality

### Version 0.6.0

* Add `QuotientFilter` implementation; [see issue #37](https://github.com/barrust/pyprobables/issues/37)
Expand Down
11 changes: 11 additions & 0 deletions probables/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ class CountMinSketchError(ProbablesBaseException):
def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)


class QuotientFilterError(ProbablesBaseException):
"""Quotient Filter Exception
Args:
message (str): The error message to be reported"""

def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)
48 changes: 34 additions & 14 deletions probables/quotientfilter/quotientfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from array import array
from typing import Iterator, List, Optional

from probables.exceptions import QuotientFilterError
from probables.hashes import KeyT, SimpleHashT, fnv_1a_32
from probables.utilities import Bitarray

Expand All @@ -20,7 +21,7 @@ class QuotientFilter:
Returns:
QuotientFilter: The initialized filter
Raises:
ValueError:
QuotientFilterError: Raised when unable to initialize
Note:
The size of the QuotientFilter will be 2**q"""

Expand All @@ -44,8 +45,8 @@ def __init__(
self, quotient: int = 20, auto_expand: bool = True, hash_function: Optional[SimpleHashT] = None
): # needs to be parameterized
if quotient < 3 or quotient > 31:
raise ValueError(
f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided"
raise QuotientFilterError(
f"Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided"
)
self.__set_params(quotient, auto_expand, hash_function)

Expand Down Expand Up @@ -140,20 +141,24 @@ def add(self, key: KeyT) -> None:
"""Add key to the quotient filter
Args:
key (str|bytes): The element to add"""
key (str|bytes): The element to add
Raises:
QuotientFilterError: Raised when no locations are available in which to insert"""
_hash = self._hash_func(key, 0)
self.add_alt(_hash)

def add_alt(self, _hash: int) -> None:
"""Add the pre-hashed value to the quotient filter
Args:
_hash (int): The element to add"""
_hash (int): The element to add
Raises:
QuotientFilterError: Raised when no locations are available in which to insert"""
if self._auto_resize and self.load_factor >= self._max_load_factor:
self.resize()
key_quotient = _hash >> self._r
key_remainder = _hash & ((1 << self._r) - 1)
if self._contained_at_loc(key_quotient, key_remainder) == -1:
if self._auto_resize and self.load_factor >= self._max_load_factor:
self.resize()
self._add(key_quotient, key_remainder)

def check(self, key: KeyT) -> bool:
Expand All @@ -177,7 +182,7 @@ def check_alt(self, _hash: int) -> bool:
key_remainder = _hash & ((1 << self._r) - 1)
return not self._contained_at_loc(key_quotient, key_remainder) == -1

def iter_hashes(self) -> Iterator[int]:
def hashes(self) -> Iterator[int]:
"""A generator over the hashes in the quotient filter
Yields:
Expand Down Expand Up @@ -220,25 +225,25 @@ def get_hashes(self) -> List[int]:
Returns:
list(int): The hash values stored in the quotient filter"""
return list(self.iter_hashes())
return list(self.hashes())

def resize(self, quotient: Optional[int] = None) -> None:
"""Resize the quotient filter to use the new quotient size
Args:
int: The new quotient to use
quotient (int): The new quotient to use
Note:
If `None` is provided, the quotient filter will double in size (quotient + 1)
Raises:
ValueError: When the new quotient will not accommodate the elements already added"""
QuotientFilterError: When the new quotient will not accommodate the elements already added"""
if quotient is None:
quotient = self._q + 1

if self.elements_added >= (1 << quotient):
raise ValueError("Unable to shrink since there will be too many elements in the quotient filter")
raise QuotientFilterError("Unable to shrink since there will be too many elements in the quotient filter")
if quotient < 3 or quotient > 31:
raise ValueError(
f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided"
raise QuotientFilterError(
f"Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided"
)

hashes = self.get_hashes()
Expand All @@ -251,6 +256,19 @@ def resize(self, quotient: Optional[int] = None) -> None:
for _h in hashes:
self.add_alt(_h)

def merge(self, second: "QuotientFilter") -> None:
"""Merge the `second` quotient filter into the first
Args:
second (QuotientFilter): The quotient filter to merge
Note:
The hashing function between the two filters should match
Note:
Errors can occur if the quotient filter being inserted into does not expand (i.e., auto_expand=False)"""

for _h in second.hashes():
self.add_alt(_h)

def _shift_insert(self, k, v, start, j, flag):
if self._is_occupied[j] == 0 and self._is_continuation[j] == 0 and self._is_shifted[j] == 0:
self._filter[j] = v
Expand Down Expand Up @@ -311,6 +329,8 @@ def _get_start_index(self, k):
return j

def _add(self, q: int, r: int):
if self._size == self._elements_added:
raise QuotientFilterError("Unable to insert the element due to insufficient space")
if self._is_occupied[q] == 0 and self._is_continuation[q] == 0 and self._is_shifted[q] == 0:
self._filter[q] = r
self._is_occupied[q] = 1
Expand Down
61 changes: 50 additions & 11 deletions tests/quotientfilter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pathlib import Path
from tempfile import NamedTemporaryFile

from probables.exceptions import QuotientFilterError

this_dir = Path(__file__).parent
sys.path.insert(0, str(this_dir))
sys.path.insert(0, str(this_dir.parent))
Expand Down Expand Up @@ -49,6 +51,10 @@ def test_qf_init(self):
self.assertEqual(qf.num_elements, 16777216) # 2**qf.quotient
self.assertFalse(qf.auto_expand)

# reset auto_expand
qf.auto_expand = True
self.assertTrue(qf.auto_expand)

def test_qf_add_check(self):
"test that the qf is able to add and check elements"
qf = QuotientFilter(quotient=8)
Expand Down Expand Up @@ -91,10 +97,10 @@ def test_qf_add_check_in(self):

def test_qf_init_errors(self):
"""test quotient filter initialization errors"""
self.assertRaises(ValueError, lambda: QuotientFilter(quotient=2))
self.assertRaises(ValueError, lambda: QuotientFilter(quotient=32))
self.assertRaises(QuotientFilterError, lambda: QuotientFilter(quotient=2))
self.assertRaises(QuotientFilterError, lambda: QuotientFilter(quotient=32))

def test_retrieve_hashes(self):
def test_qf_retrieve_hashes(self):
"""test retrieving hashes back from the quotient filter"""
qf = QuotientFilter(quotient=8, auto_expand=False)
hashes = []
Expand All @@ -107,7 +113,7 @@ def test_retrieve_hashes(self):
self.assertEqual(qf.elements_added, len(out_hashes))
self.assertEqual(set(hashes), set(out_hashes))

def test_resize(self):
def test_qf_resize(self):
"""test resizing the quotient filter"""
qf = QuotientFilter(quotient=8, auto_expand=False)
for i in range(200):
Expand All @@ -120,7 +126,7 @@ def test_resize(self):
self.assertEqual(qf.bits_per_elm, 32)
self.assertFalse(qf.auto_expand)

self.assertRaises(ValueError, lambda: qf.resize(7)) # should be too small to fit
self.assertRaises(QuotientFilterError, lambda: qf.resize(7)) # should be too small to fit

qf.resize(17)
self.assertEqual(qf.elements_added, 200)
Expand All @@ -132,7 +138,7 @@ def test_resize(self):
for i in range(200):
self.assertTrue(qf.check(str(i)))

def test_auto_resize(self):
def test_qf_auto_resize(self):
"""test resizing the quotient filter automatically"""
qf = QuotientFilter(quotient=8, auto_expand=True)
self.assertEqual(qf.max_load_factor, 0.85)
Expand All @@ -153,7 +159,7 @@ def test_auto_resize(self):
self.assertEqual(qf.remainder, 23)
self.assertEqual(qf.bits_per_elm, 32)

def test_auto_resize_changed_max_load_factor(self):
def test_qf_auto_resize_changed_max_load_factor(self):
"""test resizing the quotient filter with a different load factor"""
qf = QuotientFilter(quotient=8, auto_expand=True)
self.assertEqual(qf.max_load_factor, 0.85)
Expand All @@ -178,13 +184,46 @@ def test_auto_resize_changed_max_load_factor(self):
self.assertEqual(qf.remainder, 23)
self.assertEqual(qf.bits_per_elm, 32)

def test_resize_errors(self):
def test_qf_resize_errors(self):
"""test resizing errors"""

qf = QuotientFilter(quotient=8, auto_expand=True)
for i in range(200):
qf.add(str(i))

self.assertRaises(ValueError, lambda: qf.resize(quotient=2))
self.assertRaises(ValueError, lambda: qf.resize(quotient=32))
self.assertRaises(ValueError, lambda: qf.resize(quotient=6))
self.assertRaises(QuotientFilterError, lambda: qf.resize(quotient=2))
self.assertRaises(QuotientFilterError, lambda: qf.resize(quotient=32))
self.assertRaises(QuotientFilterError, lambda: qf.resize(quotient=6))

def test_qf_merge(self):
"""test merging two quotient filters together"""
qf = QuotientFilter(quotient=8, auto_expand=True)
for i in range(200):
qf.add(str(i))

fq = QuotientFilter(quotient=8)
for i in range(300, 500):
fq.add(str(i))

qf.merge(fq)

for i in range(200):
self.assertTrue(qf.check(str(i)))
for i in range(200, 300):
self.assertFalse(qf.check(str(i)))
for i in range(300, 500):
self.assertTrue(qf.check(str(i)))

self.assertEqual(qf.elements_added, 400)

def test_qf_merge_error(self):
"""test unable to merge due to inability to grow"""
qf = QuotientFilter(quotient=8, auto_expand=False)
for i in range(200):
qf.add(str(i))

fq = QuotientFilter(quotient=8)
for i in range(300, 400):
fq.add(str(i))

self.assertRaises(QuotientFilterError, lambda: qf.merge(fq))

0 comments on commit c1b8310

Please sign in to comment.