diff --git a/hoi/core/entropies.py b/hoi/core/entropies.py index fe44165d..63f5c7f0 100644 --- a/hoi/core/entropies.py +++ b/hoi/core/entropies.py @@ -317,7 +317,7 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array: # bin size computation bins_arr = (x.max(axis=1) - x.min(axis=1)) / n_bins - bin_size = jnp.prod(bins_arr) + bin_s = jnp.prod(bins_arr) # binning of the data x_min, x_max = x.min(axis=1, keepdims=True), x.max(axis=1, keepdims=True) @@ -338,9 +338,7 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array: probs = counts / n_samples - bin_s = jnp.where(probs != 0, bin_size, 0) - - return -jax.scipy.special.rel_entr(probs, bin_s).sum() / jnp.log(base) + return bin_s * jax.scipy.special.entr(probs / bin_s).sum() / jnp.log(base) ###############################################################################