diff --git a/examples/it/plot_entropies.py b/examples/it/plot_entropies.py
index 77f67619..cb6ba476 100644
--- a/examples/it/plot_entropies.py
+++ b/examples/it/plot_entropies.py
@@ -35,8 +35,8 @@
 # list of estimators to compare
 metrics = {
     "GC": get_entropy("gc"),
-    "KNN-3": get_entropy("knn", k=3),
     "Gaussian": get_entropy(method="gauss"),
+    "KNN-3": get_entropy("knn", k=3),
     "Kernel": get_entropy("kernel"),
 }
diff --git a/examples/it/plot_mi.py b/examples/it/plot_mi.py
index 74fbfe3a..05e0a4eb 100644
--- a/examples/it/plot_mi.py
+++ b/examples/it/plot_mi.py
@@ -39,12 +39,8 @@
 
 def mi_binning(x, y, **kwargs):
-
-    bin_x, _ = digitize(x.T, **kwargs)
-    bin_y, _ = digitize(y.T, **kwargs)
-
-    x = bin_x.T
-    y = bin_y.T
+    x = digitize(x.T, **kwargs).T
+    y = digitize(y.T, **kwargs).T
 
     return mi_binning_fcn(x, y)
 
diff --git a/hoi/core/entropies.py b/hoi/core/entropies.py
index 4731adbc..823c8604 100644
--- a/hoi/core/entropies.py
+++ b/hoi/core/entropies.py
@@ -12,7 +12,7 @@
 from jax.scipy.stats import gaussian_kde
 
 from hoi.utils.logging import logger
-from hoi.utils.stats import normalize, digitize_hist
+from hoi.utils.stats import normalize
 
 ###############################################################################
 ###############################################################################
@@ -257,7 +257,7 @@ def entropy_gauss(x: jnp.array) -> jnp.array:
 
 
 @partial(jax.jit, static_argnums=(1,))
-def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:
+def entropy_bin(x: jnp.array, base: int = 2) -> jnp.array:
     """Entropy using binning.
 
     Parameters
@@ -265,11 +265,8 @@ def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:
     x : array_like
         Input data of shape (n_features, n_samples). The data should already
         be discretize
-    base : float | 2
+    base : int | 2
         The logarithmic base to use. Default is base 2.
-    bin_size : float | None
-        The size of all the bins. Will be taken in consideration only if all
-        bins have the same size, for histogram estimator.
 
     Returns
     -------
@@ -288,7 +285,7 @@ def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:
     )[1]
     probs = counts / n_samples
 
-    return (jax.scipy.special.entr(probs)).sum() / jnp.log(base)
+    return jax.scipy.special.entr(probs).sum() / jnp.log(base)
 
 
 ###############################################################################
@@ -299,6 +296,20 @@ def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:
 
 
 @partial(jax.jit, static_argnums=(1,))
+def digitize_1d_hist(x: jnp.array, n_bins: int = 8):
+    """One dimensional digitization."""
+    assert x.ndim == 1
+    x_min, x_max = x.min(), x.max()
+    dx = (x_max - x_min) / n_bins
+    x_binned = ((x - x_min) / dx).astype(int)
+    x_binned = jnp.minimum(x_binned, n_bins - 1)
+    return x_binned.astype(int)
+
+
+# digitize_hist_2d = jax.jit(jax.vmap(digitize_1d_hist, (0, ), 0))
+
+
+@partial(jax.jit, static_argnums=(1, 2))
 def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
     """Entropy using binning.
@@ -318,7 +329,16 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
         Entropy of x (in bits)
     """
-    x_binned, bin_size = digitize_hist(x, n_bins, axis=1)
+    # bin size computation
+    bins_arr = (x.max(axis=1) - x.min(axis=1)) / n_bins
+    bin_size = jnp.prod(bins_arr)
+
+    digitize_hist_2d = jax.vmap(
+        partial(digitize_1d_hist, n_bins=n_bins), in_axes=0
+    )
+
+    # binning of the data
+    x_binned = digitize_hist_2d(x)
 
     n_features, n_samples = x_binned.shape
 
diff --git a/hoi/utils/stats.py b/hoi/utils/stats.py
index 84ce839c..34dda36f 100644
--- a/hoi/utils/stats.py
+++ b/hoi/utils/stats.py
@@ -181,48 +181,6 @@ def digitize(x, n_bins, axis=0, use_sklearn=False, **kwargs):
         return np.apply_along_axis(digitize_sklearn, axis, x, **kwargs)
 
 
-def digitize_1d_hist(x, n_bins):
-    """One dimensional digitization."""
-    assert x.ndim == 1
-    x_min, x_max = x.min(), x.max()
-    dx = (x_max - x_min) / n_bins
-    x_binned = ((x - x_min) / dx).astype(int)
-    x_binned = jnp.minimum(x_binned, n_bins - 1)
-    return x_binned.astype(int)
-
-
-def digitize_hist(x, n_bins, axis=0, **kwargs):
-    """Discretize a continuous variable.
-
-    Parameters
-    ----------
-    x : array_like
-        Array to discretize
-    n_bins : int
-        Number of bins
-    axis : int | 0
-        Axis along which to perform the discretization. By default,
-        discretization is performed along the first axis (n_samples,)
-    kwargs : dict | {}
-        Additional arguments are passed to
-        sklearn.preprocessing.KBinsDiscretizer. For example, use
-        `strategy='quantile'` for equal population binning.
-
-    Returns
-    -------
-    x_binned : array_like
-        Digitized array with the same shape as x
-    b_size : float
-        Size of the bin used
-    """
-    # In case use_sklearn = False, all bins have the same size. In this case,
-    # in order to allow the histogram estimator, also the size of the bins is
-    # returned.
-    bins_arr = (x.max(axis=axis) - x.min(axis=axis)) / n_bins
-    b_size = jnp.prod(bins_arr)
-    return jnp.apply_along_axis(digitize_1d_hist, axis, x, n_bins), b_size
-
-
 @partial(jax.jit, static_argnums=(1, 2))
diff --git a/hoi/utils/tests/test_stats.py b/hoi/utils/tests/test_stats.py
index 379a0e7e..1d90238c 100644
--- a/hoi/utils/tests/test_stats.py
+++ b/hoi/utils/tests/test_stats.py
@@ -31,7 +31,6 @@ class TestStats(object):
     @pytest.mark.parametrize("bins", [n + 2 for n in range(5)])
     @pytest.mark.parametrize("sklearn", [True, False])
     def test_digitize(self, arr, bins, sklearn):
-
         x_binned = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn)
         assert arr.shape == x_binned.shape
         for row in x_binned:
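
Reviewer sketch (not part of the patch): a minimal way to exercise the
relocated `digitize_1d_hist` and the re-jitted `entropy_hist` together. The
import path `hoi.core.entropies` is taken from the diff above; the random
data, seed, and bin counts are illustrative assumptions only.

    # Sketch: exercises the new code paths from this diff. Shapes, seed and
    # bin counts are illustrative, not part of the PR.
    from functools import partial

    import jax
    import jax.numpy as jnp
    import numpy as np

    from hoi.core.entropies import digitize_1d_hist, entropy_hist

    rng = np.random.default_rng(42)
    x = jnp.asarray(rng.standard_normal((3, 500)))  # (n_features, n_samples)

    # per-feature binning, vmapped over the feature axis exactly as
    # entropy_hist now does internally
    digitize_hist_2d = jax.vmap(partial(digitize_1d_hist, n_bins=8), in_axes=0)
    x_binned = digitize_hist_2d(x)
    assert x_binned.shape == x.shape and int(x_binned.max()) <= 7

    # n_bins is a static argument: each distinct value compiles once and is
    # cached afterwards
    h8 = entropy_hist(x, base=2, n_bins=8)
    h16 = entropy_hist(x, base=2, n_bins=16)
    print(float(h8), float(h16))

Making `n_bins` static mirrors the existing `base` argument and lets XLA see
the bin count as a compile-time constant, which is why `static_argnums` on
`entropy_hist` grows from `(1,)` to `(1, 2)`.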