Commit

Merge pull request #38 from chrisferreyra13/main

Fixing MI estimator with KNN
EtienneCmb authored Apr 8, 2024
2 parents 4753bee + b1dd2b8 commit c6edaa6
Showing 3 changed files with 83 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -110,6 +110,10 @@ ENV/
 env.bak/
 venv.bak/
 
+# Pipenv
+Pipfile
+Pipfile.lock
+
 # Spyder project settings
 .spyderproject
 .spyproject
81 changes: 76 additions & 5 deletions hoi/core/mi.py
@@ -2,6 +2,7 @@
 
 import jax
 import jax.numpy as jnp
+from jax.scipy.special import digamma as psi
 
 from .entropies import get_entropy
 
@@ -28,11 +29,14 @@ def get_mi(method="gcmi", **kwargs):
         Function to compute mutual information on variables of shapes
         (n_features, n_samples)
     """
-    # get the entropy unction
-    _entropy = get_entropy(method=method, **kwargs)
+    if method == "knn":
+        return partial(compute_mi_knn, **kwargs)
+    else:
+        # get the entropy function
+        _entropy = get_entropy(method=method, **kwargs)
 
-    # wrap the mi function with it
-    return partial(compute_mi, entropy_fcn=_entropy)
+        # wrap the mi function with it
+        return partial(compute_mi, entropy_fcn=_entropy)
 
 
 ###############################################################################
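For context, a minimal usage sketch of the patched dispatch (not part of the commit): with `method="knn"` the factory now returns `partial(compute_mi_knn, **kwargs)` directly, while any other method keeps the entropy-based `compute_mi` wrapper. The import path `hoi.core.mi` and the toy data below are assumptions based on the file location, not taken from the repository docs.

```python
# Usage sketch; the import path hoi.core.mi and the toy data are assumptions.
import numpy as np
import jax.numpy as jnp

from hoi.core.mi import get_mi

rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal((1, 500)))            # (n_features, n_samples)
y = x + 0.5 * jnp.asarray(rng.standard_normal((1, 500)))  # noisy copy of x

# "knn" bypasses get_entropy and returns partial(compute_mi_knn, **kwargs)
mi_knn = get_mi(method="knn")
# any other method (e.g. the default "gcmi") still wraps compute_mi
mi_gcmi = get_mi(method="gcmi")

print(float(mi_knn(x, y)))  # KSG estimate of I(x; y) in bits (default k=1)
```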
@@ -57,7 +61,7 @@ def compute_mi_comb(inputs, comb, mi=None):
 
 ###############################################################################
 ###############################################################################
-# OTHERS
+# GENERAL MUTUAL INFORMATION
 ###############################################################################
 ###############################################################################
 
@@ -87,3 +91,70 @@ def compute_mi(x, y, entropy_fcn=None):
         - entropy_fcn(jnp.concatenate((x, y), axis=0))
     )
     return mi
+
+
+###############################################################################
+###############################################################################
+# KNN MUTUAL INFORMATION
+###############################################################################
+###############################################################################
+
+
+@partial(jax.jit, static_argnums=(2,))
+def n_neighbours(xy, idx, k=1):
+    """Return neighbour counts of point idx based on its kth neighbour."""
+    xi, x = xy[0][:, [idx]], xy[0]
+    yi, y = xy[1][:, [idx]], xy[1]
+
+    # compute euclidean distance from xi to all points in x (same for y)
+    eucl_xi = jnp.sqrt(jnp.sum((xi - x) ** 2, axis=0))
+    eucl_yi = jnp.sqrt(jnp.sum((yi - y) ** 2, axis=0))
+
+    # distance in the joint (X, Y) space is the maximum of the two distances
+    max_dist_xy = jnp.maximum(eucl_xi, eucl_yi)
+    # indices of the closest points in the joint (X, Y) space
+    closest_points = jnp.argsort(max_dist_xy)
+    # the kth neighbour is at index k (index 0 is the point itself)
+    # distance from point idx to its kth neighbour in the joint space
+    dist_k = max_dist_xy[closest_points[k]]
+    # don't include the idx-th point itself in nx and ny
+    nx = (eucl_xi < dist_k).sum() - 1
+    ny = (eucl_yi < dist_k).sum() - 1
+
+    return xy, (nx, ny)
+
+
+@partial(jax.jit, static_argnums=(2,))
+def compute_mi_knn(x, y, k: int = 1) -> jnp.array:
+    """Mutual information using the KSG estimator.
+
+    First algorithm proposed in Kraskov et al., "Estimating mutual
+    information", Phys. Rev. E, 2004.
+
+    Parameters
+    ----------
+    x, y : array_like
+        Input data of shape (n_features, n_samples).
+    k : int
+        Number of nearest neighbors to consider for the KSG estimator.
+
+    Returns
+    -------
+    mi : float
+        Floating value describing the mutual-information between x and y.
+    """
+    # n_samples
+    n = float(x.shape[1])
+
+    _n_neighbours = partial(n_neighbours, k=k)
+    # get the number of neighbours for each point in the joint (X, Y) space
+    _, n_neighbors = jax.lax.scan(
+        _n_neighbours, (x, y), jnp.arange(int(n)).astype(int)
+    )
+    nx = n_neighbors[0]
+    ny = n_neighbors[1]
+
+    psi_mean = jnp.sum((psi(nx + 1) + psi(ny + 1)) / n)
+
+    mi = psi(k) - psi_mean + psi(n)
+    return mi / jnp.log(2)
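Taken together, compute_mi_knn implements Kraskov et al.'s first estimator: in nats, mi = psi(k) - mean_i[psi(nx_i + 1) + psi(ny_i + 1)] + psi(n), where nx_i and ny_i count the points closer to point i than its kth neighbour in the joint (X, Y) space, and the final division by jnp.log(2) converts the result to bits. Below is a sanity-check sketch (not part of the commit; the import path is assumed from the file location hoi/core/mi.py): for jointly Gaussian variables with correlation rho the closed-form MI is -0.5 * log2(1 - rho ** 2), so the KSG estimate should land close to that value for a large sample.

```python
# Sanity-check sketch; import path and example data are assumptions.
import numpy as np
import jax.numpy as jnp

from hoi.core.mi import compute_mi_knn

rho, n_samples = 0.8, 5000
rng = np.random.default_rng(42)
# jointly Gaussian sample with correlation rho, shaped (2, n_samples)
xy = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=n_samples).T

x = jnp.asarray(xy[[0]])  # (1, n_samples)
y = jnp.asarray(xy[[1]])  # (1, n_samples)

mi_est = compute_mi_knn(x, y, 4)        # k=4 passed positionally (static arg)
mi_true = -0.5 * np.log2(1.0 - rho**2)  # closed form, about 0.74 bits
print(float(mi_est), mi_true)
```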
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -1,2 +1,5 @@
 [tool.black]
 line-length = 79
+
+[tool.ruff]
+line-length = 79
