Skip to content

Commit 52321ac

Browse files
committed
Adapt all metrics + fix unit test
1 parent 20f5f5d commit 52321ac

File tree

9 files changed

+249
-108
lines changed

9 files changed

+249
-108
lines changed

hoi/core/tests/test_combinatory.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,79 @@
11
import pytest
2+
23
import numpy as np
4+
import jax.numpy as jnp
5+
36
from math import comb as ccomb
47
from hoi.core.combinatory import combinations, _combinations
58
from collections.abc import Iterable
69

710

811
class TestCombinatory(object):
12+
@pytest.mark.parametrize("target", [[], [20], [20, 21]])
913
@pytest.mark.parametrize(
1014
"n", [np.random.randint(5, 10) for _ in range(10)]
1115
)
1216
@pytest.mark.parametrize(
1317
"k", [np.random.randint(5, 10) for _ in range(10)]
1418
)
1519
@pytest.mark.parametrize("order", [True, False])
16-
def test_single_combinations(self, n, k, order):
17-
c = list(_combinations(n, k, order))
20+
def test_single_combinations(self, n, k, order, target):
21+
c = list(_combinations(n, k, order, target))
22+
23+
# test that the number of combinations is correct
1824
assert len(c) == ccomb(n, k)
19-
pass
2025

21-
@pytest.mark.parametrize("n", [np.random.randint(5, 10) for _ in range(2)])
22-
@pytest.mark.parametrize(
23-
"min", [np.random.randint(1, 10) for _ in range(2)]
24-
)
26+
# check the order
27+
if order:
28+
assert all([o == k + len(target) for o in c])
29+
else:
30+
assert all([len(o) == k + len(target) for o in c])
31+
32+
# check that targets are included
33+
if len(target) and not order:
34+
assert all([all([m in o for m in target]) for o in c])
35+
36+
@pytest.mark.parametrize("fill", [-1, -10])
37+
@pytest.mark.parametrize("target", [None, [20], [20, 21]])
38+
@pytest.mark.parametrize("order", [True, False])
39+
@pytest.mark.parametrize("astype", ["numpy", "jax", "iterator"])
2540
@pytest.mark.parametrize(
2641
"max", [_ for _ in range(2)]
2742
) # addition to minimum size
28-
@pytest.mark.parametrize("astype", ["numpy", "jax", "iterator"])
29-
@pytest.mark.parametrize("order_val", [True, False])
30-
def test_combinations(self, n, min, max, astype, order_val):
31-
combs = combinations(n, min, min + max, astype, order_val)
32-
assert isinstance(combs, Iterable)
33-
pass
43+
@pytest.mark.parametrize(
44+
"min", [np.random.randint(1, 10) for _ in range(2)]
45+
)
46+
@pytest.mark.parametrize("n", [np.random.randint(5, 10) for _ in range(2)])
47+
def test_combinations(self, n, min, max, astype, order, target, fill):
48+
# get combinations
49+
combs = combinations(
50+
n,
51+
min,
52+
maxsize=min + max,
53+
astype=astype,
54+
order=order,
55+
target=target,
56+
fill_value=fill,
57+
)
58+
59+
# check the number of multiplets
60+
n_mults = 0
61+
for c in range(min, min + max + 1):
62+
n_mults += ccomb(n, c)
63+
if astype in ["jax", "numpy"]:
64+
assert combs.shape[0] == n_mults
65+
elif astype == "iterator":
66+
assert len([c for c in combs]) == n_mults
67+
68+
# check type
69+
if astype == "numpy":
70+
assert isinstance(combs, np.ndarray)
71+
elif astype == "jax":
72+
assert isinstance(combs, jnp.ndarray)
73+
elif astype == "iterator":
74+
assert isinstance(combs, Iterable)
75+
76+
77+
if __name__ == "__main__":
78+
# TestCombinatory().test_single_combinations(5, 3, False, [21, 22])
79+
TestCombinatory().test_combinations(10, 2, 3, "iterator", False, None, -1)

hoi/metrics/base_hoi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,6 @@ def compute_entropies(
216216
pbar.close()
217217

218218
self._entropies = h_x
219-
self._multiplets = h_idx
220-
self._order = order
221219

222220
return h_x, h_idx, order
223221

hoi/metrics/gradient_oinfo.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,17 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
8080
The NumPy array containing values of higher-rder interactions of
8181
shape (n_multiplets, n_variables)
8282
"""
83-
# ____________________________ TASK-FREE ______________________________
84-
hoi_tf = self._oinf_tf.fit(
83+
kw_oinfo = dict(
8584
minsize=minsize, maxsize=maxsize, method=method, **kwargs
8685
)
8786

87+
# ____________________________ TASK-FREE ______________________________
88+
hoi_tf = self._oinf_tf.fit(**kw_oinfo)
89+
8890
self._multiplets = self._oinf_tf._multiplets
8991

9092
# __________________________ TASK-RELATED _____________________________
91-
hoi_tr = self._oinf_tr.fit(
92-
minsize=minsize,
93-
maxsize=maxsize,
94-
method=method,
95-
**kwargs
96-
)
93+
hoi_tr = self._oinf_tr.fit(**kw_oinfo)
9794

9895
return hoi_tr - hoi_tf
9996

@@ -118,4 +115,4 @@ def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
118115
model = GradientOinfo(x, y=y)
119116
hoi = model.fit(minsize=2, maxsize=None, method="gcmi")
120117

121-
print(get_nbest_mult(hoi, model=model, minsize=3, maxsize=3, n_best=3))
118+
print(get_nbest_mult(hoi, model=model, minsize=3, maxsize=3, n_best=3))

hoi/metrics/info_tot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from math import comb as ccomb
21
from functools import partial
32

43
import numpy as np
@@ -7,7 +6,6 @@
76
import jax.numpy as jnp
87

98
from hoi.metrics.base_hoi import HOIEstimator
10-
from hoi.core.combinatory import combinations
119
from hoi.core.entropies import prepare_for_entropy
1210
from hoi.core.mi import get_mi, compute_mi_comb
1311
from hoi.utils.progressbar import get_pbar

hoi/metrics/infotopo.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24

35
import jax
@@ -64,12 +66,6 @@ class InfoTopo(HOIEstimator):
6466
x : array_like
6567
Standard NumPy arrays of shape (n_samples, n_features) or
6668
(n_samples, n_features, n_variables)
67-
y : array_like
68-
The feature of shape (n_samples,) for estimating task-related O-info.
69-
multiplets : list | None
70-
List of multiplets to compute. Should be a list of multiplets, for
71-
example [(0, 1, 2), (2, 7, 8, 9)]. By default, all multiplets are
72-
going to be computed.
7369
7470
References
7571
----------
@@ -83,9 +79,14 @@ class InfoTopo(HOIEstimator):
8379
_negative = "synergy"
8480
_symmetric = True
8581

86-
def __init__(self, x, y=None, multiplets=None, verbose=None):
82+
def __init__(self, x, y=None, verbose=None):
83+
# for infotopo, the multiplets are set to None because this metric
84+
# first require to compute entropies and then associate them to form
85+
# the MI. Same for the target y.
86+
if y is not None:
87+
warnings.warn("For InfoTopo, y input is going to be ignored.")
8788
HOIEstimator.__init__(
88-
self, x=x, y=y, multiplets=multiplets, verbose=verbose
89+
self, x=x, y=None, multiplets=None, verbose=verbose
8990
)
9091

9192
def fit(self, minsize=1, maxsize=None, method="gcmi", **kwargs):
@@ -120,20 +121,20 @@ def fit(self, minsize=1, maxsize=None, method="gcmi", **kwargs):
120121
# ____________________________ ENTROPIES ______________________________
121122

122123
minsize, maxsize = self._check_minmax(minsize, maxsize)
123-
h_x, h_idx, order = self.compute_entropies(
124+
h_x, h_idx, _ = self.compute_entropies(
124125
minsize=1, maxsize=maxsize, method=method, **kwargs
125126
)
126-
n_mult = h_x.shape[0]
127127

128128
# _______________________________ HOI _________________________________
129129

130130
# compute order and multiply entropies
131+
order = (h_idx >= 0).sum(1)
131132
h_x_sgn = jnp.multiply(((-1.0) ** (order.reshape(-1, 1) - 1)), h_x)
132133

133134
# subselection of multiplets
134135
mults, _ = self.get_combinations(minsize, maxsize=maxsize)
135136
h_idx_2 = jnp.where(mults == -1, -2, mults)
136-
n_mult = h_idx_2.shape[0]
137+
n_mult = mults.shape[0]
137138

138139
# progress-bar definition
139140
pbar = scan_tqdm(n_mult, message="Mutual information")
@@ -150,28 +151,23 @@ def fit(self, minsize=1, maxsize=None, method="gcmi", **kwargs):
150151

151152
if __name__ == "__main__":
152153
import matplotlib.pyplot as plt
153-
from hoi.utils import landscape, get_nbest_mult
154-
from matplotlib.colors import LogNorm
154+
from hoi.utils import get_nbest_mult
155+
from hoi.plot import plot_landscape
155156

156157
plt.style.use("ggplot")
157158

158159
x = np.random.rand(200, 7)
159160
y_red = np.random.rand(x.shape[0])
160161

161-
# redundancy: (1, 2, 6) + (7, 8)
162-
x[:, 1] += y_red
163-
x[:, 2] += y_red
164-
x[:, 6] += y_red
162+
# redundancy: (1, 2, 6)
163+
x[:, 2] += x[:, 1]
164+
x[:, 6] += x[:, 2]
165165
# synergy: (0, 3, 5) + (7, 8)
166-
y_syn = x[:, 0] + x[:, 3] + x[:, 5]
167-
# bivariate target
168-
y = np.c_[y_red, y_syn]
166+
x[:, 0] = x[:, 0] + x[:, 3] + x[:, 5]
169167

170-
model = InfoTopo(x, y=y)
171-
hoi = model.fit(maxsize=None, method="gcmi")
168+
model = InfoTopo(x)
169+
hoi = model.fit(minsize=3, maxsize=5, method="gcmi")
172170
print(get_nbest_mult(hoi, model=model, minsize=3, maxsize=3, n_best=3))
173171

174-
lscp = landscape(hoi.squeeze(), model.order, output="xarray")
175-
lscp.plot(x="order", y="bins", cmap="jet", norm=LogNorm())
176-
plt.axvline(model.undersampling, linestyle="--", color="k")
172+
plot_landscape(hoi, model=model)
177173
plt.show()

hoi/metrics/red_mmi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from math import comb as ccomb
21
from functools import partial
32

43
import numpy as np
@@ -7,7 +6,6 @@
76
import jax.numpy as jnp
87

98
from hoi.metrics.base_hoi import HOIEstimator
10-
from hoi.core.combinatory import combinations
119
from hoi.core.entropies import prepare_for_entropy
1210
from hoi.core.mi import get_mi, compute_mi_comb
1311
from hoi.utils.progressbar import get_pbar

hoi/metrics/rsi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from math import comb as ccomb
21
from functools import partial
32

43
import numpy as np
@@ -7,7 +6,6 @@
76
import jax.numpy as jnp
87

98
from hoi.metrics.base_hoi import HOIEstimator
10-
from hoi.core.combinatory import combinations
119
from hoi.core.entropies import prepare_for_entropy
1210
from hoi.core.mi import get_mi, compute_mi_comb
1311
from hoi.utils.progressbar import get_pbar

hoi/metrics/syn_mmi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from math import comb as ccomb
21
from functools import partial
32

43
import numpy as np
@@ -7,7 +6,6 @@
76
import jax.numpy as jnp
87

98
from hoi.metrics.base_hoi import HOIEstimator
10-
from hoi.core.combinatory import combinations
119
from hoi.core.entropies import prepare_for_entropy
1210
from hoi.core.mi import get_mi, compute_mi_comb
1311
from hoi.utils.progressbar import get_pbar

0 commit comments

Comments
 (0)