From b1dd2b8a76ce942908216a3d968a63b07414d800 Mon Sep 17 00:00:00 2001 From: Christian Ferreyra Date: Sun, 7 Apr 2024 12:41:46 +0200 Subject: [PATCH] fixed mi knn estimator by adding compute_mi_knn --- .gitignore | 4 +++ hoi/core/mi.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++---- pyproject.toml | 3 ++ 3 files changed, 83 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 0166c8e9..41b9a5b7 100644 --- a/.gitignore +++ b/.gitignore @@ -110,6 +110,10 @@ ENV/ env.bak/ venv.bak/ +# Pipenv +Pipfile +Pipfile.lock + # Spyder project settings .spyderproject .spyproject diff --git a/hoi/core/mi.py b/hoi/core/mi.py index 9229e98e..55b3d55d 100644 --- a/hoi/core/mi.py +++ b/hoi/core/mi.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +from jax.scipy.special import digamma as psi from .entropies import get_entropy @@ -28,11 +29,14 @@ def get_mi(method="gcmi", **kwargs): Function to compute mutual information on variables of shapes (n_features, n_samples) """ - # get the entropy unction - _entropy = get_entropy(method=method, **kwargs) + if method == "knn": + return partial(compute_mi_knn, **kwargs) + else: + # get the entropy function + _entropy = get_entropy(method=method, **kwargs) - # wrap the mi function with it - return partial(compute_mi, entropy_fcn=_entropy) + # wrap the mi function with it + return partial(compute_mi, entropy_fcn=_entropy) ############################################################################### @@ -57,7 +61,7 @@ def compute_mi_comb(inputs, comb, mi=None): ############################################################################### ############################################################################### -# OTHERS +# GENERAL MUTUAL INFORMATION ############################################################################### ############################################################################### @@ -87,3 +91,70 @@ def compute_mi(x, y, entropy_fcn=None): - entropy_fcn(jnp.concatenate((x, y), axis=0)) ) return mi + + +############################################################################### +############################################################################### +# KNN MUTUAL INFORMATION +############################################################################### +############################################################################### + + +@partial(jax.jit, static_argnums=(2,)) +def n_neighbours(xy, idx, k=1): + """Return number of neighbours for each point based on kth neighbour.""" + xi, x = xy[0][:, [idx]], xy[0] + yi, y = xy[1][:, [idx]], xy[1] + + # compute euclidian distance from xi to all points in x (same y) + eucl_xi = jnp.sqrt(jnp.sum((xi - x) ** 2, axis=0)) + eucl_yi = jnp.sqrt(jnp.sum((yi - y) ** 2, axis=0)) + + # distance in space (XxY) is the maximum distance. + max_dist_xy = jnp.maximum(eucl_xi, eucl_yi) + # indices to the closest points in the (XxY) space. + closest_points = jnp.argsort(max_dist_xy) + # the kth neighbour is at index k (ignoring the point itself) + # distance to the k-th neighbor for each point + dist_k = max_dist_xy[closest_points[k]] + # don't include the `i`th point itself in nx and ny + nx = (eucl_xi < dist_k).sum() - 1 + ny = (eucl_yi < dist_k).sum() - 1 + + return xy, (nx, ny) + + +@partial(jax.jit, static_argnums=(2,)) +def compute_mi_knn(x, y, k: int = 1) -> jnp.array: + """Mutual information using the KSG estimator. + + First algorithm proposed in Kraskov et al., Estimating mutual information, + Phy rev, 2004. + + Parameters + ---------- + x, y : array_like + Input data of shape (n_features, n_samples). + k : int + Number of nearest neighbors to consider for the KSG estimator. + + Returns + ------- + mi : float + Floating value describing the mutual-information between x and y. + """ + # n_samples + n = float(x.shape[1]) + + _n_neighbours = partial(n_neighbours, k=k) + # get number of neighbors for each point in XxY space + _, n_neighbors = jax.lax.scan( + _n_neighbours, (x, y), jnp.arange(int(n)).astype(int) + ) + nx = n_neighbors[0] + ny = n_neighbors[1] + + psi_mean = jnp.sum((psi(nx + 1) + psi(ny + 1)) / n) + + mi = psi(k) - psi_mean + psi(n) + return mi / jnp.log(2) diff --git a/pyproject.toml b/pyproject.toml index 9216134b..f4362d4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,5 @@ [tool.black] +line-length = 79 + +[tool.ruff] line-length = 79 \ No newline at end of file