-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_BHHash.py
53 lines (47 loc) · 1.64 KB
/
test_BHHash.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# pylint: disable=W0621, C0116
import numpy as np
import pytest
from P2HNNS.utils import Query
from P2HNNS.methods import BHHash
from P2HNNS.utils.distance_functions import DistDP2H
@pytest.fixture
def bhhash_setup():
    """Provide a freshly constructed BHHash index to each test that requests it.

    Configuration: 10-dimensional inputs, 5 hyperplanes per hash function,
    3 hash functions, and an expected 20 data points.
    """
    return BHHash(
        d=10,  # dimensionality of the data
        m=5,   # number of hyperplanes per hash function
        l=3,   # number of hash functions
        n=20,  # expected number of data points
    )
def test_initialization(bhhash_setup):
    """Check stored m/l and that randu/randv exist with length m * l * d."""
    index = bhhash_setup
    assert index.m == 5
    assert index.l == 3
    dim = 10  # must agree with the ``d`` passed to BHHash in bhhash_setup
    for name in ("randu", "randv"):
        assert hasattr(index, name)
    expected_shape = (index.m * index.l * dim,)
    for name in ("randu", "randv"):
        assert getattr(index, name).shape == expected_shape
def test_hash_data(bhhash_setup):
    """hash_data on a single point yields one integer code per hash function.

    The input vector comes from a seeded generator so a failing input can be
    reproduced exactly. The index's own random vectors are created by the
    fixture and are not controlled by this seed.
    """
    bhhash = bhhash_setup
    rng = np.random.default_rng(0)
    data = rng.random(10)  # length 10 == d used by the bhhash_setup fixture
    hash_codes = bhhash.hash_data(data)
    assert len(hash_codes) == bhhash.l
    assert np.issubdtype(hash_codes.dtype, np.integer)
def test_hash_query(bhhash_setup):
    """hash_query on a single query yields one integer code per hash function.

    The query vector comes from a seeded generator so a failing input can be
    reproduced exactly. The index's own random vectors are created by the
    fixture and are not controlled by this seed.
    """
    bhhash = bhhash_setup
    rng = np.random.default_rng(0)
    query = rng.random(10)  # length 10 == d used by the bhhash_setup fixture
    hash_codes = bhhash.hash_query(query)
    assert len(hash_codes) == bhhash.l
    assert np.issubdtype(hash_codes.dtype, np.integer)
def test_build_index(bhhash_setup):
    """After build_index on 20 points, the bucket table is not empty."""
    bhhash = bhhash_setup
    rng = np.random.default_rng(0)  # seeded so a failing dataset is reproducible
    data = rng.random((20, 10))  # 20 points of dimension 10, matching the fixture
    bhhash.build_index(data)
    # Use truthiness rather than ``is False``. An identity check against the
    # Python singleton would spuriously fail if is_empty() returned a
    # numpy.bool_ or another falsy non-bool value.
    assert not bhhash.buckets.is_empty()
def test_nns(bhhash_setup):
    """nns with the DistDP2H distance returns no more than ``top`` results."""
    bhhash = bhhash_setup
    rng = np.random.default_rng(0)  # seeded so a failing data/query pair is reproducible
    data = rng.random((20, 10))  # 20 points of dimension 10, matching the fixture
    query = rng.random(10)
    bhhash.build_index(data)
    top = 5  # single source of truth for the requested result count
    param = Query(query=query, data=data, top=top, limit=10, dist=DistDP2H())
    results = bhhash.nns(param)
    # Fewer than ``top`` results is accepted; more is a failure.
    assert len(results) <= top