from typing import Callable, Sequence, Union

import torch
from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["HSIC"]


class HSIC(Metric):
    r"""Calculates the `Hilbert-Schmidt Independence Criterion (HSIC)
    <https://papers.nips.cc/paper_files/paper/2007/hash/d5cfead94f5350c12c322b5b664544c1-Abstract.html>`_.

    .. math::
        \text{HSIC}(X,Y) = \frac{1}{B(B-3)}\left[ \text{tr}(\tilde{\mathbf{K}}\tilde{\mathbf{L}})
        + \frac{\mathbf{1}^\top \tilde{\mathbf{K}} \mathbf{11}^\top \tilde{\mathbf{L}} \mathbf{1}}{(B-1)(B-2)}
        - \frac{2}{B-2}\mathbf{1}^\top \tilde{\mathbf{K}}\tilde{\mathbf{L}} \mathbf{1} \right]

    where :math:`B` is the batch size, and :math:`\tilde{\mathbf{K}}`
    and :math:`\tilde{\mathbf{L}}` are the Gram matrices of
    the Gaussian RBF kernel with their diagonal entries set to zero.

    HSIC measures non-linear statistical independence between features :math:`X` and :math:`Y`.
    HSIC becomes zero if and only if :math:`X` and :math:`Y` are independent.

    This metric computes the unbiased estimator of HSIC proposed in
    `Song et al. (2012) <https://jmlr.csail.mit.edu/papers/v13/song12a.html>`_.
    HSIC is estimated for each batch using Eq. (5) of the paper, and the average over batches is accumulated.

    Each batch must contain at least four samples.

    - ``update`` must receive output of the form ``(y_pred, y)``.

    Args:
        sigma_x: bandwidth of the kernel for :math:`X`.
            If negative, a heuristic value determined by the median of the distances between
            the samples is used. Default: -1
        sigma_y: bandwidth of the kernel for :math:`Y`.
            If negative, a heuristic value determined by the median of the distances
            between the samples is used. Default: -1
        ignore_invalid_batch: If ``True``, computation for a batch with fewer than four samples is skipped.
            If ``False``, a ``ValueError`` is raised when such a batch is received.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to the update method.
            Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        ``y_pred`` and ``y`` should have the same shape.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = HSIC()
            metric.attach(default_evaluator, "hsic")
            X = torch.tensor([[0., 1., 2., 3., 4.],
                              [5., 6., 7., 8., 9.],
                              [10., 11., 12., 13., 14.],
                              [15., 16., 17., 18., 19.],
                              [20., 21., 22., 23., 24.],
                              [25., 26., 27., 28., 29.],
                              [30., 31., 32., 33., 34.],
                              [35., 36., 37., 38., 39.],
                              [40., 41., 42., 43., 44.],
                              [45., 46., 47., 48., 49.]])
            Y = torch.sin(X * torch.pi * 2 / 50)
            state = default_evaluator.run([[X, Y]])
            print(state.metrics["hsic"])

        .. testoutput::

            0.09226646274328232

    .. versionadded:: 0.5.2
    """

    def __init__(
        self,
        sigma_x: float = -1,
        sigma_y: float = -1,
        ignore_invalid_batch: bool = True,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        super().__init__(output_transform, device, skip_unrolling=skip_unrolling)

        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.ignore_invalid_batch = ignore_invalid_batch

    _state_dict_all_req_keys = ("_sum_of_hsic", "_num_batches")

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_hsic = torch.tensor(0.0, device=self._device)
        self._num_batches = 0

    @reinit__is_reduced
    def update(self, output: Sequence[Tensor]) -> None:
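        # Each sample is flattened to a vector, so X and Y have shape (batch_size, num_features).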
        X = output[0].detach().flatten(start_dim=1)
        Y = output[1].detach().flatten(start_dim=1)
        b = X.shape[0]

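        # The unbiased estimator (Eq. (5) of Song et al., 2012) divides by (B - 1), (B - 2)
        # and B(B - 3), so each batch needs at least four samples.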
        if b <= 3:
            if self.ignore_invalid_batch:
                return
            else:
                raise ValueError(f"A batch must contain at least four samples, got only {b}.")

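        # Zero-diagonal mask: multiplying a Gram matrix by it yields the tilde matrices of Eq. (5).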
        mask = 1.0 - torch.eye(b, device=X.device)

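        # Pairwise squared Euclidean distances via the Gram matrix:
        # dxx[i, j] = x_i.x_i + x_j.x_j - 2 * x_i.x_j = ||x_i - x_j||^2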
        xx = X @ X.T
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        dxx = rx.T + rx - xx * 2

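        # Median heuristic: when sigma_x < 0, the median of the squared distances
        # serves as the kernel variance instead of a user-provided bandwidth.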
        vx: Union[Tensor, float]
        if self.sigma_x < 0:
            vx = torch.quantile(dxx, 0.5)
        else:
            vx = self.sigma_x**2
        K = torch.exp(-0.5 * dxx / vx) * mask

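        # Same distance computation and RBF kernel construction for Y.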
        yy = Y @ Y.T
        ry = yy.diag().unsqueeze(0).expand_as(yy)
        dyy = ry.T + ry - yy * 2

        vy: Union[Tensor, float]
        if self.sigma_y < 0:
            vy = torch.quantile(dyy, 0.5)
        else:
            vy = self.sigma_y**2
        L = torch.exp(-0.5 * dyy / vy) * mask

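        # The three terms of Eq. (5): tr(KL), (1'K1)(1'L1) / ((B-1)(B-2)), and 2/(B-2) * 1'KL1,
        # where K and L already have zero diagonals thanks to the mask above.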
        KL = K @ L
        trace = KL.trace()
        second_term = K.sum() * L.sum() / ((b - 1) * (b - 2))
        third_term = KL.sum() / (b - 2)

        hsic = trace + second_term - third_term * 2.0
        hsic /= b * (b - 3)
        hsic = torch.clamp(hsic, min=0.0)  # HSIC must not be negative
        self._sum_of_hsic += hsic.to(self._device)

        self._num_batches += 1

    @sync_all_reduce("_sum_of_hsic", "_num_batches")
    def compute(self) -> float:
        if self._num_batches == 0:
            raise NotComputableError("HSIC must have at least one batch before it can be computed.")

        return self._sum_of_hsic.item() / self._num_batches
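

if __name__ == "__main__":
    # Minimal standalone sketch (not part of the ignite API): drives the metric
    # directly through reset()/update()/compute(), without attaching it to an Engine.
    # The tensors below are illustrative assumptions, not test fixtures.
    torch.manual_seed(0)

    metric = HSIC()
    metric.reset()

    x = torch.randn(32, 8)
    y_dependent = torch.sin(x)  # a deterministic function of x, so HSIC should be well above zero
    metric.update((x, y_dependent))
    print(f"HSIC(x, sin(x)) = {metric.compute():.4f}")

    metric.reset()
    y_independent = torch.randn(32, 8)  # drawn independently of x, so HSIC should be near zero
    metric.update((x, y_independent))
    print(f"HSIC(x, noise)  = {metric.compute():.4f}")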