Commit 5ed5772: cleaning and rearranging the pull request
Matteo NERI authored and committed on Sep 3, 2024 (1 parent: 77c2432)
Showing 5 changed files with 31 additions and 58 deletions.
examples/it/plot_entropies.py (2 changes: 1 addition & 1 deletion)
@@ -35,8 +35,8 @@
 # list of estimators to compare
 metrics = {
     "GC": get_entropy("gc"),
-    "KNN-3": get_entropy("knn", k=3),
     "Gaussian": get_entropy(method="gauss"),
+    "KNN-3": get_entropy("knn", k=3),
     "Kernel": get_entropy("kernel"),
 }

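This hunk only reorders the estimator dictionary so the entries read GC, Gaussian, KNN-3, Kernel. For orientation, here is a minimal standalone sketch of how such a dictionary of estimators can be driven; the import path and the (n_features, n_samples) call convention are assumptions inferred from the hoi/core/entropies.py diff below, not lines of this example:

# Hypothetical usage sketch, not part of the commit.
import numpy as np
from hoi.core import get_entropy

metrics = {
    "GC": get_entropy("gc"),
    "Gaussian": get_entropy(method="gauss"),
    "KNN-3": get_entropy("knn", k=3),
    "Kernel": get_entropy("kernel"),
}

x = np.random.randn(1, 1000)  # one feature, 1000 samples
for name, fcn in metrics.items():
    print(name, float(fcn(x)))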
examples/it/plot_mi.py (8 changes: 2 additions & 6 deletions)
@@ -39,12 +39,8 @@


 def mi_binning(x, y, **kwargs):
-
-    bin_x, _ = digitize(x.T, **kwargs)
-    bin_y, _ = digitize(y.T, **kwargs)
-
-    x = bin_x.T
-    y = bin_y.T
+    x = digitize(x.T, **kwargs).T
+    y = digitize(y.T, **kwargs).T
 
     return mi_binning_fcn(x, y)

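The old body unpacked a (binned_array, bin_size) tuple; the rewritten one assumes digitize returns only the binned array, which is consistent with the removal of the tuple-returning digitize_hist from hoi/utils/stats.py below, so each call collapses to a single expression. A minimal sketch of the resulting call pattern; the signature follows the stats.py diff, while the import path and the n_bins value are illustrative assumptions:

# Sketch of the simplified binning step, not part of the commit.
import numpy as np
from hoi.utils import digitize

x = np.random.randn(200, 3)  # (n_samples, n_features)
x_binned = digitize(x, n_bins=8, axis=0)
assert x_binned.shape == x.shape  # digitization preserves shape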
hoi/core/entropies.py (36 changes: 28 additions & 8 deletions)
@@ -12,7 +12,7 @@
 from jax.scipy.stats import gaussian_kde
 
 from hoi.utils.logging import logger
-from hoi.utils.stats import normalize, digitize_hist
+from hoi.utils.stats import normalize
 
 ###############################################################################
 ###############################################################################
@@ -257,19 +257,16 @@ def entropy_gauss(x: jnp.array) -> jnp.array:


 @partial(jax.jit, static_argnums=(1,))
-def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:
+def entropy_bin(x: jnp.array, base: int = 2) -> jnp.array:
     """Entropy using binning.
 
     Parameters
     ----------
     x : array_like
         Input data of shape (n_features, n_samples). The data should already
         be discretized.
-    base : float | 2
+    base : int | 2
         The logarithmic base to use. Default is base 2.
-    bin_size : float | None
-        The size of all the bins. Will be taken in consideration only if all
-        bins have the same size, for histogram estimator.
 
     Returns
     -------
@@ -288,7 +285,7 @@ def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:
     )[1]
     probs = counts / n_samples
 
-    return (jax.scipy.special.entr(probs)).sum() / jnp.log(base)
+    return jax.scipy.special.entr(probs).sum() / jnp.log(base)
 
 
 ###############################################################################
@@ -299,6 +296,20 @@ def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:


+@partial(jax.jit, static_argnums=(1,))
+def digitize_1d_hist(x: jnp.array, n_bins: int = 8):
+    """One dimensional digitization."""
+    assert x.ndim == 1
+    x_min, x_max = x.min(), x.max()
+    dx = (x_max - x_min) / n_bins
+    x_binned = ((x - x_min) / dx).astype(int)
+    x_binned = jnp.minimum(x_binned, n_bins - 1)
+    return x_binned.astype(int)
+
+
+# digitize_hist_2d = jax.jit(jax.vmap(digitize_1d_hist, (0, ), 0))
+
+
 @partial(jax.jit, static_argnums=(1, 2))
 def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
     """Entropy using binning.
@@ -318,7 +329,16 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
         Entropy of x (in bits)
     """
 
-    x_binned, bin_size = digitize_hist(x, n_bins, axis=1)
+    # bin size computation
+    bins_arr = (x.max(axis=1) - x.min(axis=1)) / n_bins
+    bin_size = jnp.prod(bins_arr)
+
+    digitize_hist_2d = jax.vmap(
+        partial(digitize_1d_hist, n_bins=n_bins), in_axes=0
+    )
+
+    # binning of the data
+    x_binned = digitize_hist_2d(x)
 
     n_features, n_samples = x_binned.shape
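Taken together, the hunks above inline what digitize_hist used to provide: digitize_1d_hist bins a single feature into equal-width bins, jax.vmap maps it over the feature axis inside entropy_hist, and the bin volume bin_size (the product of per-feature bin widths) is computed directly. The sketch below reproduces that pipeline end to end as a standalone function, reusing the unique-column counting visible in the entropy_bin hunk. The name entropy_hist_sketch and the final + jnp.log(bin_size) correction are assumptions for illustration: the correction is the usual discrete-to-differential histogram term, but its exact use lies outside the lines shown in this diff.

# Standalone sketch of the new entropy_hist pipeline, not the library code.
from functools import partial

import jax
import jax.numpy as jnp


def digitize_1d_hist(x, n_bins=8):
    """Equal-width digitization of a 1D array, as added in this commit."""
    x_min, x_max = x.min(), x.max()
    dx = (x_max - x_min) / n_bins
    x_binned = ((x - x_min) / dx).astype(int)
    return jnp.minimum(x_binned, n_bins - 1)


def entropy_hist_sketch(x, base=2.0, n_bins=8):
    # bin volume = product of per-feature bin widths
    bin_size = jnp.prod((x.max(axis=1) - x.min(axis=1)) / n_bins)

    # vmap the 1D digitization over the feature axis, as in the diff
    digitize_2d = jax.vmap(partial(digitize_1d_hist, n_bins=n_bins), in_axes=0)
    x_binned = digitize_2d(x)

    # joint probabilities from counts of unique columns (joint states),
    # matching the unique/counts pattern of the entropy_bin hunk
    counts = jnp.unique(x_binned, axis=1, return_counts=True)[1]
    probs = counts / x_binned.shape[1]

    # discrete entropy plus an assumed bin-volume correction (in nats),
    # converted to the requested logarithmic base
    h = jax.scipy.special.entr(probs).sum() + jnp.log(bin_size)
    return h / jnp.log(base)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 512))
print(entropy_hist_sketch(x))  # entropy in bits for base=2

Building the vmapped digitizer inside entropy_hist (rather than at module level, as in the commented-out digitize_hist_2d line) keeps n_bins bound via partial, which fits its declaration as a static argument in static_argnums=(1, 2).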
hoi/utils/stats.py (42 changes: 0 additions & 42 deletions)
@@ -181,48 +181,6 @@ def digitize(x, n_bins, axis=0, use_sklearn=False, **kwargs):
     return np.apply_along_axis(digitize_sklearn, axis, x, **kwargs)
 
 
-def digitize_1d_hist(x, n_bins):
-    """One dimensional digitization."""
-    assert x.ndim == 1
-    x_min, x_max = x.min(), x.max()
-    dx = (x_max - x_min) / n_bins
-    x_binned = ((x - x_min) / dx).astype(int)
-    x_binned = jnp.minimum(x_binned, n_bins - 1)
-    return x_binned.astype(int)
-
-
-def digitize_hist(x, n_bins, axis=0, **kwargs):
-    """Discretize a continuous variable.
-
-    Parameters
-    ----------
-    x : array_like
-        Array to discretize
-    n_bins : int
-        Number of bins
-    axis : int | 0
-        Axis along which to perform the discretization. By default,
-        discretization is performed along the first axis (n_samples,)
-    kwargs : dict | {}
-        Additional arguments are passed to
-        sklearn.preprocessing.KBinsDiscretizer. For example, use
-        `strategy='quantile'` for equal population binning.
-
-    Returns
-    -------
-    x_binned : array_like
-        Digitized array with the same shape as x
-    b_size : float
-        Size of the bin used
-    """
-    # In case use_sklearn = False, all bins have the same size. In this case,
-    # in order to allow the histogram estimator, also the size of the bins is
-    # returned.
-    bins_arr = (x.max(axis=axis) - x.min(axis=axis)) / n_bins
-    b_size = jnp.prod(bins_arr)
-    return jnp.apply_along_axis(digitize_1d_hist, axis, x, n_bins), b_size
 
 
 partial(jax.jit, static_argnums=(1, 2))


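The two helpers removed here, digitize_1d_hist and its tuple-returning wrapper digitize_hist, are the same ones re-added in vmapped form in hoi/core/entropies.py above, so the logic moves rather than disappears. For reference, the equal-width rule they implement is equivalent to np.digitize against evenly spaced inner edges; a standalone sketch of that equivalence (illustration only, not commit code):

# Standalone check: the removed equal-width rule vs np.digitize.
import numpy as np

x = np.random.randn(1000)
n_bins = 8

# Rule from the removed digitize_1d_hist: floor((x - min) / dx),
# clipped to the last bin.
dx = (x.max() - x.min()) / n_bins
ours = np.minimum(((x - x.min()) / dx).astype(int), n_bins - 1)

# Reference: np.digitize against the n_bins - 1 inner edges of an
# equal-width grid over [min, max].
edges = np.linspace(x.min(), x.max(), n_bins + 1)[1:-1]
ref = np.digitize(x, edges)

# Agreement should be ~1.0, up to float rounding exactly at bin edges.
print((ours == ref).mean())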
hoi/utils/tests/test_stats.py (1 change: 0 additions & 1 deletion)
@@ -31,7 +31,6 @@ class TestStats(object):
     @pytest.mark.parametrize("bins", [n + 2 for n in range(5)])
     @pytest.mark.parametrize("sklearn", [True, False])
     def test_digitize(self, arr, bins, sklearn):
-
         x_binned = digitize(x=arr, n_bins=bins, axis=0, use_sklearn=sklearn)
         assert arr.shape == x_binned.shape
         for row in x_binned:
