
Commit 6eb96b1

Adapt metrics to vmap over the last axis

1 parent 3946e19 · commit 6eb96b1

File tree: 11 files changed (+43 / -35 lines)
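Taken together, the changes drop the (n_variables, n_features, n_samples) transpose in prepare_for_entropy, switch the per-variable estimators to a (n_samples, n_features) layout, and pass in_axes=2 to every jax.vmap call so batching runs over the trailing n_variables axis. A minimal sketch of that pattern (the per_variable_fn stand-in is hypothetical, not part of the commit):

```python
import jax
import jax.numpy as jnp

def per_variable_fn(x):
    # stand-in for an estimator that maps one (n_samples, n_features)
    # slice to a scalar, like the entropy functions in this commit
    return jnp.var(x)

data = jnp.ones((100, 3, 5))  # (n_samples, n_features, n_variables)

# in_axes=2 vectorizes over the last axis: each call sees a
# (n_samples, n_features) slice, and one value per variable comes back
batched = jax.vmap(per_variable_fn, in_axes=2)
print(batched(data).shape)  # (5,)
```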

hoi/core/entropies.py

Lines changed: 29 additions & 21 deletions

```diff
@@ -22,7 +22,7 @@
 ###############################################################################
 
 
-def get_entropy(method="gcmi", **kwargs):
+def get_entropy(method="gcmi", vmap=False, **kwargs):
     """Get entropy function.
 
     Parameters
@@ -43,7 +43,10 @@ def get_entropy(method="gcmi", **kwargs):
     elif method == "binning":
         return partial(entropy_bin, **kwargs)
     elif method == "knn":
-        return partial(entropy_knn, **kwargs)
+        # wrap distance funtion with k
+        k = kwargs.get("k", 1)
+        cdist = partial(cdistk, k=k)
+        return partial(_entropy_knn, cdist=cdist, k=k)
     elif method == "kernel":
         return partial(entropy_kernel, **kwargs)
     else:
@@ -57,7 +60,7 @@ def get_entropy(method="gcmi", **kwargs):
 ###############################################################################
 
 
-def prepare_for_entropy(data, method, reshape=True, **kwargs):
+def prepare_for_entropy(data, method, **kwargs):
     """Prepare the data before computing entropy."""
     # data.shape = n_samples, n_features, n_variables
 
@@ -84,10 +87,6 @@ def prepare_for_entropy(data, method, reshape=True, **kwargs):
     elif method == "binning":
         pass
 
-    # make the data (n_variables, n_features, n_samples)
-    if reshape:
-        data = jnp.asarray(data.transpose(2, 1, 0))
-
     return data, kwargs
 
 
@@ -110,7 +109,7 @@ def entropy_gcmi(
     Parameters
     ----------
     x : array_like
-        Array of data of shape (n_features, n_samples)
+        Array of data of shape (n_samples, n_features)
     biascorrect : bool | False
         Specifies whether bias correction should be applied to the estimated MI
     demean : bool | False
@@ -121,14 +120,14 @@
     hx : float
         Entropy of the gaussian variable (in bits)
     """
-    nfeat, nsamp = x.shape
+    nsamp, nfeat = x.shape
 
     # demean data
     if demean:
-        x = x - x.mean(axis=1, keepdims=True)
+        x = x - x.mean(axis=0, keepdims=True)
 
     # covariance
-    c = jnp.dot(x, x.T) / float(nsamp - 1)
+    c = jnp.dot(x.T, x) / float(nsamp - 1)
     chc = jnp.linalg.cholesky(c)
 
     # entropy in nats
@@ -218,7 +217,7 @@ def entropy_bin(x: jnp.array, base: int = 2) -> jnp.array:
     Parameters
     ----------
     x : array_like
-        Input data of shape (n_features, n_samples). The data should already
+        Input data of shape (n_samples, n_features). The data should already
         be discretize
     base : int | 2
         The logarithmic base to use. Default is base 2.
@@ -228,13 +227,13 @@ def entropy_bin(x: jnp.array, base: int = 2) -> jnp.array:
     hx : float
         Entropy of x
     """
-    n_features, n_samples = x.shape
+    n_samples, n_features = x.shape
     # here, we count the number of possible multiplets. The worst is that each
     # trial is unique. So we can prepare the output to be at most (n_samples,)
     # and if trials are repeated, just set to zero it's going to be compensated
     # by the entr() function
     counts = jnp.unique(
-        x, return_counts=True, size=n_samples, axis=1, fill_value=0
+        x, return_counts=True, size=n_samples, axis=0, fill_value=0
     )[1]
     probs = counts / n_samples
     return jax.scipy.special.entr(probs).sum() / np.log(base)
@@ -257,10 +256,10 @@ def set_to_inf(x, _):
 @partial(jax.jit, static_argnums=(2,))
 def cdistk(xx, idx, k=1):
     """K-th minimum euclidian distance."""
-    x, y = xx[:, [idx]], xx
+    x, y = xx[[idx], :], xx
 
     # compute euclidian distance
-    eucl = jnp.sqrt(jnp.sum((x - y) ** 2, axis=0))
+    eucl = jnp.sqrt(jnp.sum((x - y) ** 2, axis=1))
 
     # in case of 0-distances, replace them by infinity
     eucl = jnp.where(eucl == 0, jnp.inf, eucl)
@@ -271,8 +270,8 @@ def cdistk(xx, idx, k=1):
     return xx, eucl[jnp.argmin(eucl)]
 
 
-@partial(jax.jit, static_argnums=(1,))
-def entropy_knn(x: jnp.array, k: int = 1) -> jnp.array:
+@partial(jax.jit, static_argnums=(1, 2))
+def entropy_knn(x: jnp.array, k: int = 1, cdist=None) -> jnp.array:
     """Entropy using the k-nearest neighbor.
 
     Original code: https://github.com/blakeaw/Python-knn-entropy/
@@ -292,6 +291,14 @@ def entropy_knn(x: jnp.array, k: int = 1) -> jnp.array:
     hx : float
         Entropy of x
     """
+    # wrap cdist
+    cdist = partial(cdistk, k=k)
+    fcn = partial(_entropy_knn, cdist=cdist, k=k)
+    return fcn(x)
+
+
+@partial(jax.jit, static_argnums=(1, 2))
+def _entropy_knn(x: jnp.array, k: int = 1, cdist=None) -> jnp.array:
     # x = jnp.atleast_2d(x)
     d, n = float(x.shape[0]), float(x.shape[1])
 
@@ -330,14 +337,15 @@ def entropy_kernel(
     Parameters
     ----------
     x : array_like
-        Input data of shape (n_features, n_samples)
+        Input data of shape (n_samples, n_features)
 
     Returns
     -------
     hx : float
         Entropy of x
     """
-    model = gaussian_kde(x, bw_method=bw_method)
-    return -jnp.mean(jnp.log2(model(x)))
+    x_t = x.T
+    model = gaussian_kde(x_t, bw_method=bw_method)
+    return -jnp.mean(jnp.log2(model(x_t)))
     # p = model.pdf(x)
     # return jax.scipy.special.entr(p).sum() / np.log(base)
```
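A quick way to see the new entropy_gcmi convention at work: with samples on axis 0, demeaning along axis 0 and forming jnp.dot(x.T, x) / (nsamp - 1) is the usual sample covariance. A hedged sanity check against jnp.cov (not from the commit):

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))  # (n_samples, n_features)

# demean over samples, then covariance as in the new entropy_gcmi
xd = x - x.mean(axis=0, keepdims=True)
c = jnp.dot(xd.T, xd) / float(x.shape[0] - 1)  # (n_features, n_features)

# jnp.cov expects variables in rows, hence the transpose
assert jnp.allclose(c, jnp.cov(x.T), atol=1e-5)
```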

hoi/core/mi.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -70,8 +70,8 @@ def compute_mi(x, y, entropy_fcn=None):
     ----------
     x, y : array_like
         Arrays to consider for computing the Mutual Information. The two input
-        variables x and y should have a shape of (n_features_x, n_samples) and
-        (n_features_y, n_samples)
+        variables x and y should have a shape of (n_samples, n_features_x) and
+        (n_samples, n_features_y)
     entropy_fcn : function | None
         Function to use for computing the entropy.
 
@@ -84,6 +84,6 @@ def compute_mi(x, y, entropy_fcn=None):
     mi = (
         entropy_fcn(x)
         + entropy_fcn(y)
-        - entropy_fcn(jnp.concatenate((x, y), axis=0))
+        - entropy_fcn(jnp.concatenate((x, y), axis=1))
     )
     return mi
```
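With samples on axis 0, the joint variable for H(x, y) is now built by stacking features on axis 1, as the last hunk shows. A small sketch of the H(x) + H(y) - H(x, y) decomposition with a stand-in Gaussian entropy (the real entropy_fcn comes from get_entropy):

```python
import jax
import jax.numpy as jnp

def gauss_entropy(x):
    # stand-in Gaussian entropy (in nats) for (n_samples, n_features) data
    n, d = x.shape
    xd = x - x.mean(axis=0, keepdims=True)
    c = jnp.dot(xd.T, xd) / float(n - 1)
    return 0.5 * (d * jnp.log(2 * jnp.pi * jnp.e) + jnp.linalg.slogdet(c)[1])

kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (500, 2))                    # (n_samples, n_features_x)
y = 0.5 * x[:, [0]] + jax.random.normal(ky, (500, 1))  # dependent on x

# mi = h(x) + h(y) - h([x, y]); the joint is concatenated on the feature axis
mi = gauss_entropy(x) + gauss_entropy(y) - gauss_entropy(
    jnp.concatenate((x, y), axis=1)
)
print(mi)  # > 0 since x and y are correlated
```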

hoi/metrics/base_hoi.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -182,7 +182,7 @@ def compute_entropies(
         # get entropy function
         entropy = partial(
             ent_at_index,
-            entropy=jax.vmap(get_entropy(method=method, **kwargs)),
+            entropy=jax.vmap(get_entropy(method=method, **kwargs), in_axes=2),
         )
 
         # ______________________________ ENTROPY ______________________________
```
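The in_axes=2 added here (and in the metric fit methods below) is equivalent to looping over the last axis without transposing the data first; a hedged equivalence check with a stand-in estimator:

```python
import jax
import jax.numpy as jnp

def ent(x):
    # stand-in for the vmapped entropy: (n_samples, n_features) -> scalar
    return jnp.var(x)

data = jnp.arange(24.0).reshape(4, 2, 3)  # (n_samples, n_features, n_variables)

vmapped = jax.vmap(ent, in_axes=2)(data)
looped = jnp.stack([ent(data[:, :, i]) for i in range(data.shape[2])])
assert jnp.allclose(vmapped, looped)  # one entropy per variable, no transpose
```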

hoi/metrics/dtc.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -110,7 +110,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, kwargs = prepare_for_entropy(self._x, method, **kwargs)
 
         # get entropy function
-        entropy = jax.vmap(get_entropy(method=method, **kwargs))
+        entropy = jax.vmap(get_entropy(method=method, **kwargs), in_axes=2)
         dtc_no_ent = partial(
             _dtc_no_ent,
             entropy_3d=entropy,
```

hoi/metrics/info_tot.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -77,7 +77,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, y = self._split_xy(x)
 
         # prepare mi functions
-        mi_fcn = jax.vmap(get_mi(method=method, **kwargs))
+        mi_fcn = jax.vmap(get_mi(method=method, **kwargs), in_axes=2)
         compute_mi = partial(compute_mi_comb, mi=mi_fcn)
 
         # get multiplet indices and order
```

hoi/metrics/oinfo.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -21,10 +21,10 @@ def _oinfo_no_ent(inputs, index, entropy_3d=None, entropy_4d=None):
     # compute h(x^{n})
     h_xn = entropy_3d(x_c)
 
-    # compute \sum_{j=1}^{n} h(x_{j}
+    # compute \sum_{j=1}^{n} h(x_{j})
     h_xj_sum = entropy_4d(x_c[:, :, jnp.newaxis, :]).sum(0)
 
-    # compute \sum_{j=1}^{n} h(x_{-j}
+    # compute \sum_{j=1}^{n} h(x_{-j})
     h_xmj_sum = entropy_4d(x_c[:, acc, :]).sum(0)
 
     # compute oinfo
@@ -115,7 +115,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, kwargs = prepare_for_entropy(self._x, method, **kwargs)
 
         # get entropy function
-        entropy = jax.vmap(get_entropy(method=method, **kwargs))
+        entropy = jax.vmap(get_entropy(method=method, **kwargs), in_axes=2)
         oinfo_no_ent = partial(
             _oinfo_no_ent,
             entropy_3d=entropy,
```
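The corrected comments name the three entropy terms that _oinfo_no_ent combines; for reference, the O-information they assemble is usually written as follows (my transcription of Rosas et al.'s definition, not part of the diff):

```latex
\Omega(X^{n}) = (n - 2)\, H(X^{n})
              + \sum_{j=1}^{n} \left[ H(X_{j}) - H\left(X^{n}_{-j}\right) \right]
```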

hoi/metrics/red_mmi.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -74,7 +74,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, y = self._split_xy(x)
 
         # prepare mi functions
-        mi_fcn = jax.vmap(get_mi(method=method, **kwargs))
+        mi_fcn = jax.vmap(get_mi(method=method, **kwargs), in_axes=2)
         compute_mi = partial(compute_mi_comb, mi=mi_fcn)
 
         # get multiplet indices and order
```

hoi/metrics/rsi.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -101,7 +101,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, y = self._split_xy(x)
 
         # prepare mi functions
-        mi_fcn = jax.vmap(get_mi(method=method, **kwargs))
+        mi_fcn = jax.vmap(get_mi(method=method, **kwargs), in_axes=2)
         compute_mi = partial(compute_mi_comb, mi=mi_fcn)
 
         # get multiplet indices and order
```

hoi/metrics/sinfo.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -111,7 +111,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, kwargs = prepare_for_entropy(self._x, method, **kwargs)
 
         # get entropy function
-        entropy = jax.vmap(get_entropy(method=method, **kwargs))
+        entropy = jax.vmap(get_entropy(method=method, **kwargs), in_axes=2)
         sinfo_no_ent = partial(
             _sinfo_no_ent,
             entropy_3d=entropy,
```

hoi/metrics/syn_mmi.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -90,7 +90,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, y = self._split_xy(x)
 
         # prepare mi functions
-        mi_fcn = jax.vmap(get_mi(method=method, **kwargs))
+        mi_fcn = jax.vmap(get_mi(method=method, **kwargs), in_axes=2)
         compute_mi = partial(compute_mi_comb, mi=mi_fcn)
         compute_syn = partial(_compute_syn, mi_fcn=compute_mi)
 
```

hoi/metrics/tc.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -110,7 +110,7 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
         x, kwargs = prepare_for_entropy(self._x, method, **kwargs)
 
         # get entropy function
-        entropy = jax.vmap(get_entropy(method=method, **kwargs))
+        entropy = jax.vmap(get_entropy(method=method, **kwargs), in_axes=2)
         tc_no_ent = partial(
             _tc_no_ent,
             entropy_3d=entropy,
```
