Skip to content

Commit

Permalink
final changes, new version of entropy_hist
Browse files Browse the repository at this point in the history
  • Loading branch information
Matteo NERI authored and Matteo NERI committed Sep 3, 2024
1 parent 5ed5772 commit 5c92159
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 20 deletions.
1 change: 0 additions & 1 deletion examples/it/plot_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
# create a special function for the binning approach as it requires binary data
mi_binning_fcn = get_mi("binning", base=2)


def mi_binning(x, y, **kwargs):
x = digitize(x.T, **kwargs).T
y = digitize(y.T, **kwargs).T
Expand Down
23 changes: 4 additions & 19 deletions hoi/core/entropies.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,6 @@ def entropy_bin(x: jnp.array, base: int = 2) -> jnp.array:
###############################################################################


@partial(jax.jit, static_argnums=(1,))
def digitize_1d_hist(x: jnp.array, n_bins: int = 8):
"""One dimensional digitization."""
assert x.ndim == 1
x_min, x_max = x.min(), x.max()
dx = (x_max - x_min) / n_bins
x_binned = ((x - x_min) / dx).astype(int)
x_binned = jnp.minimum(x_binned, n_bins - 1)
return x_binned.astype(int)


# digitize_hist_2d = jax.jit(jax.vmap(digitize_1d_hist, (0, ), 0))


@partial(jax.jit, static_argnums=(1, 2))
def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
"""Entropy using binning.
Expand All @@ -333,12 +319,11 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
bins_arr = (x.max(axis=1) - x.min(axis=1)) / n_bins
bin_size = jnp.prod(bins_arr)

digitize_hist_2d = jax.vmap(
partial(digitize_1d_hist, n_bins=n_bins), in_axes=0
)

# binning of the data
x_binned = digitize_hist_2d(x)
x_min, x_max = x.min(axis=1, keepdims=True), x.max(axis=1, keepdims=True)
dx = (x_max - x_min) / n_bins
x_binned = ((x - x_min) / dx).astype(int)
x_binned = jnp.minimum(x_binned, n_bins - 1).astype(int)

n_features, n_samples = x_binned.shape

Expand Down

0 comments on commit 5c92159

Please sign in to comment.