Commit 8e53b76

kzkadcvfdev-5 and vfdev authored
add HSIC metric (#3282)
* add HSIC metric
* minor update on docstring
* add reference to the HSIC formula in docstring
* update version directive
* fix formatting issue
* add type hints
* accumulate HSIC value for each batch
* update test to clip value for each batch
* fix accumulator device error
* fix error in making y
* fix test to use the same linear layer across metric_devices
* Revert "fix test to use the same linear layer across metric_devices" (this reverts commit cb71355)
* Fixed distributed tests
* Fixed code formatting errors

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 9481227 · commit 8e53b76

File tree

4 files changed: +361 −0 lines changed

docs/source/metrics.rst

Lines changed: 1 addition & 0 deletions
@@ -357,6 +357,7 @@ Complete list of metrics
     KLDivergence
     JSDivergence
     MaximumMeanDiscrepancy
+    HSIC
     AveragePrecision
     CohenKappa
     GpuInfo

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from ignite.metrics.gan.fid import FID
 from ignite.metrics.gan.inception_score import InceptionScore
 from ignite.metrics.gpu_info import GpuInfo
+from ignite.metrics.hsic import HSIC
 from ignite.metrics.js_divergence import JSDivergence
 from ignite.metrics.kl_divergence import KLDivergence
 from ignite.metrics.loss import Loss
@@ -64,6 +65,7 @@
     "JaccardIndex",
     "JSDivergence",
     "KLDivergence",
+    "HSIC",
     "MaximumMeanDiscrepancy",
     "MultiLabelConfusionMatrix",
     "MutualInformation",

ignite/metrics/hsic.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
from typing import Callable, Sequence, Union

import torch
from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["HSIC"]


class HSIC(Metric):
    r"""Calculates the `Hilbert-Schmidt Independence Criterion (HSIC)
    <https://papers.nips.cc/paper_files/paper/2007/hash/d5cfead94f5350c12c322b5b664544c1-Abstract.html>`_.

    .. math::
        \text{HSIC}(X,Y) = \frac{1}{B(B-3)}\left[ \text{tr}(\tilde{\mathbf{K}}\tilde{\mathbf{L}})
        + \frac{\mathbf{1}^\top \tilde{\mathbf{K}} \mathbf{11}^\top \tilde{\mathbf{L}} \mathbf{1}}{(B-1)(B-2)}
        - \frac{2}{B-2}\mathbf{1}^\top \tilde{\mathbf{K}}\tilde{\mathbf{L}} \mathbf{1} \right]

    where :math:`B` is the batch size, and :math:`\tilde{\mathbf{K}}` and :math:`\tilde{\mathbf{L}}`
    are the Gram matrices of the Gaussian RBF kernel with their diagonal entries set to zero.

    HSIC measures non-linear statistical independence between features :math:`X` and :math:`Y`.
    HSIC becomes zero if and only if :math:`X` and :math:`Y` are independent.

    This metric computes the unbiased estimator of HSIC proposed in
    `Song et al. (2012) <https://jmlr.csail.mit.edu/papers/v13/song12a.html>`_.
    HSIC is estimated using Eq. (5) of the paper for each batch, and the average over batches is accumulated.

    Each batch must contain at least four samples.

    - ``update`` must receive output of the form ``(y_pred, y)``.

    Args:
        sigma_x: bandwidth of the kernel for :math:`X`.
            If negative, a heuristic value determined by the median of the distances between
            the samples is used. Default: -1
        sigma_y: bandwidth of the kernel for :math:`Y`.
            If negative, a heuristic value determined by the median of the distances
            between the samples is used. Default: -1
        ignore_invalid_batch: if ``True``, computation is skipped for a batch with fewer than four samples.
            If ``False``, a ``ValueError`` is raised when such a batch is received.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to the update method.
            Should be true for a multi-output model, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        ``y_pred`` and ``y`` should have the same shape.

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = HSIC()
            metric.attach(default_evaluator, "hsic")
            X = torch.tensor([[0., 1., 2., 3., 4.],
                              [5., 6., 7., 8., 9.],
                              [10., 11., 12., 13., 14.],
                              [15., 16., 17., 18., 19.],
                              [20., 21., 22., 23., 24.],
                              [25., 26., 27., 28., 29.],
                              [30., 31., 32., 33., 34.],
                              [35., 36., 37., 38., 39.],
                              [40., 41., 42., 43., 44.],
                              [45., 46., 47., 48., 49.]])
            Y = torch.sin(X * torch.pi * 2 / 50)
            state = default_evaluator.run([[X, Y]])
            print(state.metrics["hsic"])

        .. testoutput::

            0.09226646274328232

    .. versionadded:: 0.5.2
    """

    def __init__(
        self,
        sigma_x: float = -1,
        sigma_y: float = -1,
        ignore_invalid_batch: bool = True,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        super().__init__(output_transform, device, skip_unrolling=skip_unrolling)

        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.ignore_invalid_batch = ignore_invalid_batch

    _state_dict_all_req_keys = ("_sum_of_hsic", "_num_batches")

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_hsic = torch.tensor(0.0, device=self._device)
        self._num_batches = 0

    @reinit__is_reduced
    def update(self, output: Sequence[Tensor]) -> None:
        X = output[0].detach().flatten(start_dim=1)
        Y = output[1].detach().flatten(start_dim=1)
        b = X.shape[0]

        if b <= 3:
            if self.ignore_invalid_batch:
                return
            else:
                raise ValueError(f"A batch must contain at least four samples, got only {b} samples.")

        mask = 1.0 - torch.eye(b, device=X.device)

        xx = X @ X.T
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        dxx = rx.T + rx - xx * 2

        vx: Union[Tensor, float]
        if self.sigma_x < 0:
            vx = torch.quantile(dxx, 0.5)
        else:
            vx = self.sigma_x**2
        K = torch.exp(-0.5 * dxx / vx) * mask

        yy = Y @ Y.T
        ry = yy.diag().unsqueeze(0).expand_as(yy)
        dyy = ry.T + ry - yy * 2

        vy: Union[Tensor, float]
        if self.sigma_y < 0:
            vy = torch.quantile(dyy, 0.5)
        else:
            vy = self.sigma_y**2
        L = torch.exp(-0.5 * dyy / vy) * mask

        KL = K @ L
        trace = KL.trace()
        second_term = K.sum() * L.sum() / ((b - 1) * (b - 2))
        third_term = KL.sum() / (b - 2)

        hsic = trace + second_term - third_term * 2.0
        hsic /= b * (b - 3)
        hsic = torch.clamp(hsic, min=0.0)  # HSIC must not be negative
        self._sum_of_hsic += hsic.to(self._device)

        self._num_batches += 1

    @sync_all_reduce("_sum_of_hsic", "_num_batches")
    def compute(self) -> float:
        if self._num_batches == 0:
            raise NotComputableError("HSIC must have at least one batch before it can be computed.")

        return self._sum_of_hsic.item() / self._num_batches
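
The metric above can also be driven outside of an ``Engine`` through its ``reset``/``update``/``compute`` API. The following is a minimal standalone sketch, not part of this commit; the batch shapes, the number of batches, and the dependence between ``X`` and ``Y`` are arbitrary choices for illustration:

import torch

from ignite.metrics import HSIC

# Minimal standalone sketch (illustrative, not part of this commit).
metric = HSIC(sigma_x=-1, sigma_y=-1)  # negative bandwidths select the median heuristic
metric.reset()
for _ in range(2):
    X = torch.randn(16, 8)                        # batch of 16 samples, 8 features each
    Y = torch.sin(X) + 0.01 * torch.randn(16, 8)  # non-linearly dependent on X
    metric.update((X, Y))                         # each batch needs at least four samples
print(metric.compute())  # average of the per-batch unbiased HSIC estimates

Because the estimator is computed per batch and then averaged, the result depends on the batching, which is why the tests below compare against a per-batch NumPy reference.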

tests/ignite/metrics/test_hsic.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
from typing import Tuple

import numpy as np
import pytest

import torch
from torch import nn, Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import HSIC


def np_hsic(x: Tensor, y: Tensor, sigma_x: float = -1, sigma_y: float = -1) -> float:
    x_np = x.detach().cpu().numpy()
    y_np = y.detach().cpu().numpy()
    b = x_np.shape[0]

    ii, jj = np.meshgrid(np.arange(b), np.arange(b), indexing="ij")
    mask = 1.0 - np.eye(b)

    dxx = np.square(x_np[ii] - x_np[jj]).sum(axis=2)
    if sigma_x < 0:
        vx = np.median(dxx)
    else:
        vx = sigma_x * sigma_x
    K = np.exp(-0.5 * dxx / vx) * mask

    dyy = np.square(y_np[ii] - y_np[jj]).sum(axis=2)
    if sigma_y < 0:
        vy = np.median(dyy)
    else:
        vy = sigma_y * sigma_y
    L = np.exp(-0.5 * dyy / vy) * mask

    KL = K @ L
    ones = np.ones(b)
    hsic = np.trace(KL) + (ones @ K @ ones) * (ones @ L @ ones) / ((b - 1) * (b - 2)) - ones @ KL @ ones * 2 / (b - 2)
    hsic /= b * (b - 3)
    hsic = np.clip(hsic, 0.0, None)
    return hsic


def test_zero_batch():
    hsic = HSIC()
    with pytest.raises(NotComputableError, match=r"HSIC must have at least one batch before it can be computed"):
        hsic.compute()


def test_invalid_batch():
    hsic = HSIC(ignore_invalid_batch=False)
    X = torch.tensor([[1, 2, 3]]).float()
    Y = torch.tensor([[4, 5, 6]]).float()
    with pytest.raises(ValueError, match=r"A batch must contain at least four samples, got only"):
        hsic.update((X, Y))


@pytest.fixture(params=[0, 1, 2])
def test_case(request) -> Tuple[Tensor, Tensor, int]:
    if request.param == 0:
        # independent
        N = 100
        b = 10
        x, y = torch.randn((N, 50)), torch.randn((N, 30))
    elif request.param == 1:
        # linearly dependent
        N = 100
        b = 10
        x = torch.normal(1.0, 2.0, size=(N, 10))
        y = x @ torch.rand(10, 15) * 3 + torch.randn(N, 15) * 1e-4
    else:
        # non-linearly dependent
        N = 200
        b = 20
        x = torch.randn(N, 5)
        y = x @ torch.normal(0.0, torch.pi, size=(5, 3))
        y = (
            torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2])], dim=1)
            + torch.randn_like(y) * 1e-4
        )

    return x, y, b


@pytest.mark.parametrize("n_times", range(3))
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int]):
    x, y, batch_size = test_case

    hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y)

    hsic.reset()

    np_hsic_sum = 0.0
    n_iters = y.shape[0] // batch_size
    for i in range(n_iters):
        idx = i * batch_size
        x_batch = x[idx : idx + batch_size]
        y_batch = y[idx : idx + batch_size]

        hsic.update((x_batch, y_batch))
        np_hsic_sum += np_hsic(x_batch, y_batch, sigma_x, sigma_y)
    expected_hsic = np_hsic_sum / n_iters

    assert isinstance(hsic.compute(), float)
    assert pytest.approx(expected_hsic, abs=2e-5) == hsic.compute()


def test_accumulator_detached():
    hsic = HSIC()

    x = torch.rand(10, 10, dtype=torch.float)
    y = torch.rand(10, 10, dtype=torch.float)
    hsic.update((x, y))

    assert not hsic._sum_of_hsic.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    @pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
    @pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
    def test_integration(self, sigma_x: float, sigma_y: float):
        tol = 2e-5
        n_iters = 100
        batch_size = 20
        n_dims_x = 100
        n_dims_y = 50

        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)

            lin = nn.Linear(n_dims_x, n_dims_y).to(device)
            y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y, device=device) * 1e-4

            def data_loader(i, input_x, input_y):
                return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]

            engine = Engine(lambda e, i: data_loader(i, x, y))

            m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device)
            m.attach(engine, "hsic")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=1)

            assert "hsic" in engine.state.metrics
            res = engine.state.metrics["hsic"]

            x = idist.all_gather(x)
            y = idist.all_gather(y)
            total_n_iters = idist.all_reduce(n_iters)

            np_res = 0.0
            for i in range(total_n_iters):
                x_batch, y_batch = data_loader(i, x, y)
                np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y)

            expected_hsic = np_res / total_n_iters
            assert pytest.approx(expected_hsic, abs=tol) == res

    def test_accumulator_device(self):
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)
        for metric_device in metric_devices:
            hsic = HSIC(device=metric_device)

            for dev in (hsic._device, hsic._sum_of_hsic.device):
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

            x = torch.zeros(10, 10).float()
            y = torch.ones(10, 10).float()
            hsic.update((x, y))

            for dev in (hsic._device, hsic._sum_of_hsic.device):
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
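
One step worth spelling out: ``HSIC.update`` builds the pairwise squared distances from the Gram matrix via the identity ||x_i - x_j||^2 = x_i·x_i + x_j·x_j - 2·x_i·x_j, while ``np_hsic`` above forms the differences explicitly. The sketch below (illustrative only, not part of the commit; shapes are arbitrary) checks that the two constructions, and hence the median-heuristic bandwidths, agree:

import numpy as np
import torch

# Illustrative sketch (not part of the commit): verify that the Gram-matrix
# identity used in HSIC.update matches the explicit pairwise distances in np_hsic.
b, d = 10, 5
X = torch.randn(b, d)

xx = X @ X.T
rx = xx.diag().unsqueeze(0).expand_as(xx)  # rx[i, j] = x_j . x_j
dxx_torch = rx.T + rx - xx * 2             # ||x_i - x_j||^2 via the Gram matrix

x_np = X.numpy()
ii, jj = np.meshgrid(np.arange(b), np.arange(b), indexing="ij")
dxx_np = np.square(x_np[ii] - x_np[jj]).sum(axis=2)  # explicit differences

assert np.allclose(dxx_torch.numpy(), dxx_np, atol=1e-5)
# Both code paths pick the median pairwise squared distance as the bandwidth.
assert np.isclose(torch.quantile(dxx_torch, 0.5).item(), np.median(dxx_np), atol=1e-5)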
