Skip to content

Commit

Permalink
Implement Hermitian-2 and -3 eigenvalue solvers (#2)
Browse files Browse the repository at this point in the history
* Implement Hermitian-2 eigenvalue solver

* Add matrix wrapper for closed-form solvers

* Fix bug in 2x2 determinant

* Implement Hermitian-3 eigenvalue solver

* Return Hessian submatrices by default

* Add docstrings and typehints

* Add function to directly compute eigenvalues of the Hessian matrix

* Add examples for the Hessian
  • Loading branch information
eigenvivek authored Aug 8, 2024
1 parent 9bdfde4 commit 49e72f5
Show file tree
Hide file tree
Showing 11 changed files with 571 additions and 59 deletions.
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,47 @@ plt.show()
```

![](index_files/figure-commonmark/cell-5-output-1.png)

## Hessian matrix

``` python
from diptorch.filters import hessian, hessian_eigenvalues
from einops import rearrange
```

``` python
# Hessian matrix of an image (all second-order partial derivatives)
img = astronaut()

H = hessian(img, sigma=2.5, as_matrix=True)
H = rearrange(H, "B C1 C2 H W -> B (C1 H) (C2 W)")

plt.imshow(H.squeeze(), cmap="gray")
plt.axis("off")
plt.show()
```

![](index_files/figure-commonmark/cell-7-output-1.png)

``` python
# Eigenvalues of the Hessian matrix of an image
img = astronaut()
eig = hessian_eigenvalues(img, sigma=2.5)

plt.figure(figsize=(9, 3))
plt.subplot(131)
plt.imshow(img.squeeze(), cmap="gray")
plt.axis("off")
plt.subplot(132)
plt.imshow(eig.squeeze()[0], cmap="gray")
plt.title("Smallest eigenvalue")
plt.axis("off")
plt.subplot(133)
plt.imshow(eig.squeeze()[1], cmap="gray")
plt.title("Largest eigenvalue")
plt.axis("off")
plt.tight_layout()
plt.show()
```

![](index_files/figure-commonmark/cell-8-output-1.png)
10 changes: 9 additions & 1 deletion diptorch/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
'diptorch.filters._gaussian_kernel_1d': ('filters.html#_gaussian_kernel_1d', 'diptorch/filters.py'),
'diptorch.filters._hessian_2d': ('filters.html#_hessian_2d', 'diptorch/filters.py'),
'diptorch.filters._hessian_3d': ('filters.html#_hessian_3d', 'diptorch/filters.py'),
'diptorch.filters._hessian_as_matrix': ('filters.html#_hessian_as_matrix', 'diptorch/filters.py'),
'diptorch.filters.gaussian_filter': ('filters.html#gaussian_filter', 'diptorch/filters.py'),
'diptorch.filters.hessian': ('filters.html#hessian', 'diptorch/filters.py')},
'diptorch.filters.hessian': ('filters.html#hessian', 'diptorch/filters.py'),
'diptorch.filters.hessian_eigenvalues': ('filters.html#hessian_eigenvalues', 'diptorch/filters.py')},
'diptorch.linalg': { 'diptorch.linalg._is_hermitian': ('linalg.html#_is_hermitian', 'diptorch/linalg.py'),
'diptorch.linalg._is_square': ('linalg.html#_is_square', 'diptorch/linalg.py'),
'diptorch.linalg.deth3': ('linalg.html#deth3', 'diptorch/linalg.py'),
'diptorch.linalg.eigvalsh': ('linalg.html#eigvalsh', 'diptorch/linalg.py'),
'diptorch.linalg.eigvalsh2': ('linalg.html#eigvalsh2', 'diptorch/linalg.py'),
'diptorch.linalg.eigvalsh3': ('linalg.html#eigvalsh3', 'diptorch/linalg.py')},
'diptorch.utils': {'diptorch.utils.astronaut': ('utils.html#astronaut', 'diptorch/utils.py')}}}
87 changes: 59 additions & 28 deletions diptorch/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/00_filters.ipynb.

# %% auto 0
__all__ = ['gaussian_filter', 'hessian']
__all__ = ['gaussian_filter', 'hessian', 'hessian_eigenvalues']

# %% ../notebooks/00_filters.ipynb 3
from math import ceil, sqrt
Expand Down Expand Up @@ -94,41 +94,72 @@ def _conv(
return convfn(x, weight=kernel.view(*view))

# %% ../notebooks/00_filters.ipynb 8
def hessian(img: torch.Tensor, sigma: float, **kwargs) -> torch.Tensor:
from .linalg import eigvalsh2, eigvalsh3


def hessian(
img: torch.Tensor, sigma: float, as_matrix: bool = False, **kwargs
) -> torch.Tensor:
"""Compute the Hessian of a 2D or 3D image."""
if img.ndim == 4:
return _hessian_2d(img, sigma, **kwargs)
hessian = _hessian_2d(img, sigma, **kwargs)
elif img.ndim == 5:
return _hessian_3d(img, sigma, **kwargs)
hessian = _hessian_3d(img, sigma, **kwargs)
else:
raise ValueError(f"img can only be 2D or 3D, not {img.ndim-2}D")

if as_matrix:
return _hessian_as_matrix(*hessian)
else:
return hessian


def hessian_eigenvalues(img: torch.Tensor, sigma: float, **kwargs):
H = hessian(img, sigma, **kwargs)
if len(H) == 3:
return eigvalsh2(*H)
elif len(H) == 6:
return eigvalsh3(*H)
else:
raise ValueError(f"Unrecognized number of upper triangular elements: {len(H)}")

# %% ../notebooks/00_filters.ipynb 9
def _hessian_2d(img: torch.Tensor, sigma: float, **kwargs):
xx = gaussian_filter(img, sigma, order=[0, 2], **kwargs).squeeze()
yy = gaussian_filter(img, sigma, order=[2, 0], **kwargs).squeeze()
xy = gaussian_filter(img, sigma, order=[1, 1], **kwargs).squeeze()
return torch.stack(
[
torch.stack([xx, xy], dim=-1),
torch.stack([xy, yy], dim=-1),
],
dim=-1,
)
xx = gaussian_filter(img, sigma, order=[0, 2], **kwargs)
yy = gaussian_filter(img, sigma, order=[2, 0], **kwargs)
xy = gaussian_filter(img, sigma, order=[1, 1], **kwargs)
return xx, xy, yy


def _hessian_3d(img: torch.Tensor, sigma: float, **kwargs):
xx = gaussian_filter(img, sigma, order=[0, 0, 2], **kwargs).squeeze()
yy = gaussian_filter(img, sigma, order=[0, 2, 0], **kwargs).squeeze()
zz = gaussian_filter(img, sigma, order=[2, 0, 0], **kwargs).squeeze()
xy = gaussian_filter(img, sigma, order=[0, 1, 1], **kwargs).squeeze()
xz = gaussian_filter(img, sigma, order=[1, 0, 1], **kwargs).squeeze()
yz = gaussian_filter(img, sigma, order=[1, 1, 0], **kwargs).squeeze()
return torch.stack(
[
torch.stack([xx, xy, xz], dim=-1),
torch.stack([xy, yy, yz], dim=-1),
torch.stack([xz, yz, zz], dim=-1),
],
dim=-1,
)
xx = gaussian_filter(img, sigma, order=[0, 0, 2], **kwargs)
yy = gaussian_filter(img, sigma, order=[0, 2, 0], **kwargs)
zz = gaussian_filter(img, sigma, order=[2, 0, 0], **kwargs)
xy = gaussian_filter(img, sigma, order=[0, 1, 1], **kwargs)
xz = gaussian_filter(img, sigma, order=[1, 0, 1], **kwargs)
yz = gaussian_filter(img, sigma, order=[1, 1, 0], **kwargs)
return xx, xy, xz, yy, yz, zz


def _hessian_as_matrix(*args):
if len(args) == 3:
xx, xy, yy = args
return torch.stack(
[
torch.concat([xx, xy], dim=1),
torch.concat([xy, yy], dim=1),
],
dim=1,
)
elif len(args) == 6:
xx, xy, xz, yy, yz, zz = args
return torch.stack(
[
torch.concat([xx, xy, xz], dim=1),
torch.concat([xy, yy, yz], dim=1),
torch.concat([xz, yz, zz], dim=1),
],
dim=1,
)
else:
raise ValueError(f"Invalid number of arguments: {len(args)}")
98 changes: 98 additions & 0 deletions diptorch/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/01_linalg.ipynb.

# %% auto 0
__all__ = ['eigvalsh', 'eigvalsh2', 'eigvalsh3']

# %% ../notebooks/01_linalg.ipynb 3
import torch

# %% ../notebooks/01_linalg.ipynb 5
def _is_square(A: torch.Tensor) -> bool:
_, i, j, *_ = A.shape
assert i == j, "Matrix is not square"


def _is_hermitian(A: torch.Tensor) -> bool:
return torch.testing.assert_close(
A, A.transpose(1, 2).conj(), msg="Matrix is not Hermitian"
)

# %% ../notebooks/01_linalg.ipynb 6
def eigvalsh(A: torch.Tensor, check_valid: bool = True) -> torch.Tensor:
"""
Compute the eigenvalues of a batched tensor with shape [B C C H W (D)]
where C is 2 or 3, and the tensor is Hermitian in dimensions 1 and 2.
Returns eigenvalues in a tensor with shape [1 2 H W] or [1 3 H W D],
for 2D and 3D inputs, respectively, sorted in ascending order.
"""
if check_valid:
_is_square(A)
_is_hermitian(A)
if A.shape[1] == 2:
return eigvalsh2(*A[:, *torch.triu_indices(2, 2)].split(1, dim=1))
elif A.shape[1] == 3:
return eigvalsh3(*A[:, *torch.triu_indices(3, 3)].split(1, dim=1))
else:
raise ValueError("Only supports 2×2 and 3×3 matrices")

# %% ../notebooks/01_linalg.ipynb 7
def eigvalsh2(ii: torch.Tensor, ij: torch.Tensor, jj: torch.Tensor) -> torch.Tensor:
"""
Compute the eigenvalues of a batched Hermitian 2×2 tensor
where blocks have shape [1 1 H W].
Returns eigenvalues in a tensor with shape [1 2 H W]
sorted in ascending order.
"""
tr = ii + jj
det = ii * jj - ij.square()

disc = (tr.square() - 4 * det).sqrt()
disc = torch.concat([-disc, disc], dim=1)

eigvals = (tr + disc) / 2
return eigvals

# %% ../notebooks/01_linalg.ipynb 8
def eigvalsh3(
ii: torch.Tensor,
ij: torch.Tensor,
ik: torch.Tensor,
jj: torch.Tensor,
jk: torch.Tensor,
kk: torch.Tensor,
) -> torch.Tensor:
"""
Compute the eigenvalues of a batched Hermitian 3×3 tensor
where blocks have shape [1 1 H W D].
Returns eigenvalues in a tensor with shape [1 3 H W D]
sorted in ascending order.
"""
q = (ii + jj + kk) / 3
p1 = torch.concat([ij, ik, jk], dim=1).square().sum(1, keepdim=True)
p2 = (torch.concat([ii, jj, kk], dim=1) - q).square().sum(
dim=1, keepdim=True
) + 2 * p1
p = (p2 / 6).sqrt()

r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / p.pow(3) / 2
r = r.clamp(-1, 1)
phi = r.arccos() / 3

eig3 = q + 2 * p * phi.cos()
eig1 = q + 2 * p * (phi + 2 * torch.pi / 3).cos()
eig2 = 3 * q - eig1 - eig3
print(eig1.shape, eig2.shape, eig3.shape)
return torch.concat([eig1, eig2, eig3], dim=1)

# %% ../notebooks/01_linalg.ipynb 9
def deth3(ii, ij, ik, jj, jk, kk):
return (
ii * jj * kk
+ 2 * ij * ik * jk
- ii * jk.square()
- jj * ik.square()
- kk * ij.square()
)
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ channels:
dependencies:
- pip
- pytorch
- einops
- matplotlib
- scikit-image
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
91 changes: 62 additions & 29 deletions notebooks/00_filters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hessian matrix of an image"
"## Hessian matrix of an image\n",
"\n",
"Compute a symmetric matrix of all second-order partial Gaussian derivatives of an image."
]
},
{
Expand All @@ -162,14 +164,34 @@
"outputs": [],
"source": [
"#| export\n",
"def hessian(img: torch.Tensor, sigma: float, **kwargs) -> torch.Tensor:\n",
"from diptorch.linalg import eigvalsh2, eigvalsh3\n",
"\n",
"\n",
"def hessian(\n",
" img: torch.Tensor, sigma: float, as_matrix: bool = False, **kwargs\n",
") -> torch.Tensor:\n",
" \"\"\"Compute the Hessian of a 2D or 3D image.\"\"\"\n",
" if img.ndim == 4:\n",
" return _hessian_2d(img, sigma, **kwargs)\n",
" hessian = _hessian_2d(img, sigma, **kwargs)\n",
" elif img.ndim == 5:\n",
" return _hessian_3d(img, sigma, **kwargs)\n",
" hessian = _hessian_3d(img, sigma, **kwargs)\n",
" else:\n",
" raise ValueError(f\"img can only be 2D or 3D, not {img.ndim-2}D\")\n",
"\n",
" if as_matrix:\n",
" return _hessian_as_matrix(*hessian)\n",
" else:\n",
" raise ValueError(f\"img can only be 2D or 3D, not {img.ndim-2}D\")"
" return hessian\n",
"\n",
"\n",
"def hessian_eigenvalues(img: torch.Tensor, sigma: float, **kwargs):\n",
" H = hessian(img, sigma, **kwargs)\n",
" if len(H) == 3:\n",
" return eigvalsh2(*H)\n",
" elif len(H) == 6:\n",
" return eigvalsh3(*H)\n",
" else:\n",
" raise ValueError(f\"Unrecognized number of upper triangular elements: {len(H)}\")"
]
},
{
Expand All @@ -180,33 +202,44 @@
"source": [
"#| exporti\n",
"def _hessian_2d(img: torch.Tensor, sigma: float, **kwargs):\n",
" xx = gaussian_filter(img, sigma, order=[0, 2], **kwargs).squeeze()\n",
" yy = gaussian_filter(img, sigma, order=[2, 0], **kwargs).squeeze()\n",
" xy = gaussian_filter(img, sigma, order=[1, 1], **kwargs).squeeze()\n",
" return torch.stack(\n",
" [\n",
" torch.stack([xx, xy], dim=-1),\n",
" torch.stack([xy, yy], dim=-1),\n",
" ],\n",
" dim=-1,\n",
" )\n",
" xx = gaussian_filter(img, sigma, order=[0, 2], **kwargs)\n",
" yy = gaussian_filter(img, sigma, order=[2, 0], **kwargs)\n",
" xy = gaussian_filter(img, sigma, order=[1, 1], **kwargs)\n",
" return xx, xy, yy\n",
"\n",
"\n",
"def _hessian_3d(img: torch.Tensor, sigma: float, **kwargs):\n",
" xx = gaussian_filter(img, sigma, order=[0, 0, 2], **kwargs).squeeze()\n",
" yy = gaussian_filter(img, sigma, order=[0, 2, 0], **kwargs).squeeze()\n",
" zz = gaussian_filter(img, sigma, order=[2, 0, 0], **kwargs).squeeze()\n",
" xy = gaussian_filter(img, sigma, order=[0, 1, 1], **kwargs).squeeze()\n",
" xz = gaussian_filter(img, sigma, order=[1, 0, 1], **kwargs).squeeze()\n",
" yz = gaussian_filter(img, sigma, order=[1, 1, 0], **kwargs).squeeze()\n",
" return torch.stack(\n",
" [\n",
" torch.stack([xx, xy, xz], dim=-1),\n",
" torch.stack([xy, yy, yz], dim=-1),\n",
" torch.stack([xz, yz, zz], dim=-1),\n",
" ],\n",
" dim=-1,\n",
" )"
" xx = gaussian_filter(img, sigma, order=[0, 0, 2], **kwargs)\n",
" yy = gaussian_filter(img, sigma, order=[0, 2, 0], **kwargs)\n",
" zz = gaussian_filter(img, sigma, order=[2, 0, 0], **kwargs)\n",
" xy = gaussian_filter(img, sigma, order=[0, 1, 1], **kwargs)\n",
" xz = gaussian_filter(img, sigma, order=[1, 0, 1], **kwargs)\n",
" yz = gaussian_filter(img, sigma, order=[1, 1, 0], **kwargs)\n",
" return xx, xy, xz, yy, yz, zz\n",
"\n",
"\n",
"def _hessian_as_matrix(*args):\n",
" if len(args) == 3:\n",
" xx, xy, yy = args\n",
" return torch.stack(\n",
" [\n",
" torch.concat([xx, xy], dim=1),\n",
" torch.concat([xy, yy], dim=1),\n",
" ],\n",
" dim=1,\n",
" )\n",
" elif len(args) == 6:\n",
" xx, xy, xz, yy, yz, zz = args\n",
" return torch.stack(\n",
" [\n",
" torch.concat([xx, xy, xz], dim=1),\n",
" torch.concat([xy, yy, yz], dim=1),\n",
" torch.concat([xz, yz, zz], dim=1),\n",
" ],\n",
" dim=1,\n",
" )\n",
" else:\n",
" raise ValueError(f\"Invalid number of arguments: {len(args)}\")"
]
},
{
Expand Down
Loading

0 comments on commit 49e72f5

Please sign in to comment.