diff --git a/hoi/core/tests/test_entropies.py b/hoi/core/tests/test_entropies.py index 82bc4a12..02991f3f 100644 --- a/hoi/core/tests/test_entropies.py +++ b/hoi/core/tests/test_entropies.py @@ -40,7 +40,7 @@ def test_entropy_gc(self, x, biascorrect, copnorm): @pytest.mark.parametrize("x", [x1, x2, j1, j2]) def test_entropy_bin(self, x): - x_bin = digitize(x, n_bins=3) + x_bin, _ = digitize(x, n_bins=3) hx = entropy_bin(x_bin) hx = np.asarray(hx) assert hx.dtype == np.float32 diff --git a/hoi/core/tests/test_mi.py b/hoi/core/tests/test_mi.py index a05ebacb..d6bd9f55 100644 --- a/hoi/core/tests/test_mi.py +++ b/hoi/core/tests/test_mi.py @@ -34,8 +34,8 @@ def test_mi_gc(self, xy, biascorrect, copnorm): @pytest.mark.parametrize("xy", [(x1, x2), (j1, j2)]) def test_mi_bin(self, xy): mi_fcn = get_mi(method="binning") - x_binned = digitize(xy[0], n_bins=3) - y_binned = digitize(xy[1], n_bins=3) + x_binned, _ = digitize(xy[0], n_bins=3) + y_binned, _ = digitize(xy[1], n_bins=3) mi = mi_fcn(x_binned, y_binned) assert mi.dtype == np.float32 assert mi.shape == () diff --git a/hoi/utils/tests/test_stats.py b/hoi/utils/tests/test_stats.py index 1d90238c..d441a43c 100644 --- a/hoi/utils/tests/test_stats.py +++ b/hoi/utils/tests/test_stats.py @@ -31,11 +31,18 @@ class TestStats(object): @pytest.mark.parametrize("bins", [n + 2 for n in range(5)]) @pytest.mark.parametrize("sklearn", [True, False]) def test_digitize(self, arr, bins, sklearn): - x_binned = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn) - assert arr.shape == x_binned.shape - for row in x_binned: - for val in row: - assert isinstance(val, np.int64) + if sklearn: + x_binned = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn) + assert arr.shape == x_binned.shape + for row in x_binned: + for val in row: + assert isinstance(val, np.int64) + else: + x_binned, _ = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn) + assert arr.shape == x_binned.shape + for row in x_binned: + for val in row: + assert isinstance(val, np.int64) @pytest.mark.parametrize("x", [x1, x2, j2]) @pytest.mark.parametrize("minmax", [(-10.0, 10.0), (-1.0, 1.0)])