From 5c92159c497e733b3efa59767d46f0aa1dc99a46 Mon Sep 17 00:00:00 2001 From: Matteo NERI Date: Tue, 3 Sep 2024 17:11:04 +0200 Subject: [PATCH] final changes, new version of entropy_hist --- examples/it/plot_mi.py | 1 - hoi/core/entropies.py | 23 ++++------------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/examples/it/plot_mi.py b/examples/it/plot_mi.py index 05e0a4eb..454dc6c5 100644 --- a/examples/it/plot_mi.py +++ b/examples/it/plot_mi.py @@ -37,7 +37,6 @@ # create a special function for the binning approach as it requires binary data mi_binning_fcn = get_mi("binning", base=2) - def mi_binning(x, y, **kwargs): x = digitize(x.T, **kwargs).T y = digitize(y.T, **kwargs).T diff --git a/hoi/core/entropies.py b/hoi/core/entropies.py index 823c8604..fe44165d 100644 --- a/hoi/core/entropies.py +++ b/hoi/core/entropies.py @@ -295,20 +295,6 @@ def entropy_bin(x: jnp.array, base: int = 2) -> jnp.array: ############################################################################### -@partial(jax.jit, static_argnums=(1,)) -def digitize_1d_hist(x: jnp.array, n_bins: int = 8): - """One dimensional digitization.""" - assert x.ndim == 1 - x_min, x_max = x.min(), x.max() - dx = (x_max - x_min) / n_bins - x_binned = ((x - x_min) / dx).astype(int) - x_binned = jnp.minimum(x_binned, n_bins - 1) - return x_binned.astype(int) - - -# digitize_hist_2d = jax.jit(jax.vmap(digitize_1d_hist, (0, ), 0)) - - @partial(jax.jit, static_argnums=(1, 2)) def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array: """Entropy using binning. @@ -333,12 +319,11 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array: bins_arr = (x.max(axis=1) - x.min(axis=1)) / n_bins bin_size = jnp.prod(bins_arr) - digitize_hist_2d = jax.vmap( - partial(digitize_1d_hist, n_bins=n_bins), in_axes=0 - ) - # binning of the data - x_binned = digitize_hist_2d(x) + x_min, x_max = x.min(axis=1, keepdims=True), x.max(axis=1, keepdims=True) + dx = (x_max - x_min) / n_bins + x_binned = ((x - x_min) / dx).astype(int) + x_binned = jnp.minimum(x_binned, n_bins - 1).astype(int) n_features, n_samples = x_binned.shape