Skip to content

Commit

Permalink
checking for the test
Browse files Browse the repository at this point in the history
  • Loading branch information
Matteo NERI authored and Matteo NERI committed Jul 8, 2024
1 parent d415a32 commit dd1364d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion hoi/core/tests/test_entropies.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_entropy_gc(self, x, biascorrect, copnorm):

@pytest.mark.parametrize("x", [x1, x2, j1, j2])
def test_entropy_bin(self, x):
x_bin = digitize(x, n_bins=3)
x_bin, _ = digitize(x, n_bins=3)
hx = entropy_bin(x_bin)
hx = np.asarray(hx)
assert hx.dtype == np.float32
Expand Down
4 changes: 2 additions & 2 deletions hoi/core/tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def test_mi_gc(self, xy, biascorrect, copnorm):
@pytest.mark.parametrize("xy", [(x1, x2), (j1, j2)])
def test_mi_bin(self, xy):
mi_fcn = get_mi(method="binning")
x_binned = digitize(xy[0], n_bins=3)
y_binned = digitize(xy[1], n_bins=3)
x_binned, _ = digitize(xy[0], n_bins=3)
y_binned, _ = digitize(xy[1], n_bins=3)
mi = mi_fcn(x_binned, y_binned)
assert mi.dtype == np.float32
assert mi.shape == ()
Expand Down
17 changes: 12 additions & 5 deletions hoi/utils/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@ class TestStats(object):
@pytest.mark.parametrize("bins", [n + 2 for n in range(5)])
@pytest.mark.parametrize("sklearn", [True, False])
def test_digitize(self, arr, bins, sklearn):
x_binned = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn)
assert arr.shape == x_binned.shape
for row in x_binned:
for val in row:
assert isinstance(val, np.int64)
if sklearn:
x_binned = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn)
assert arr.shape == x_binned.shape
for row in x_binned:
for val in row:
assert isinstance(val, np.int64)
else:
x_binned, _ = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn)
assert arr.shape == x_binned.shape
for row in x_binned:
for val in row:
assert isinstance(val, np.int64)

@pytest.mark.parametrize("x", [x1, x2, j2])
@pytest.mark.parametrize("minmax", [(-10.0, 10.0), (-1.0, 1.0)])
Expand Down

0 comments on commit dd1364d

Please sign in to comment.