Skip to content

Commit

Permalink
Implement Gaussian filtering (#1)
Browse files Browse the repository at this point in the history
* Implement Gaussian filtering

* Add test image loader

* Add Gaussian filtering tutorial

* Fix tests
  • Loading branch information
eigenvivek authored Aug 8, 2024
1 parent b471c10 commit 9bdfde4
Show file tree
Hide file tree
Showing 13 changed files with 621 additions and 26 deletions.
67 changes: 65 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# diptorch

diptorch
================

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

Expand All @@ -8,3 +8,66 @@
``` sh
pip install diptorch
```

## Hello, World!

``` python
import matplotlib.pyplot as plt

from diptorch.filters import gaussian_filter
from diptorch.utils import astronaut
```

``` python
# Zero-th order Gaussian filter (smoothing)
img = astronaut()
img_filtered = gaussian_filter(img, sigma=2.5)

plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.imshow(img.squeeze(), cmap="gray")
plt.axis("off")
plt.subplot(122)
plt.imshow(img_filtered.squeeze(), cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()
```

![](index_files/figure-commonmark/cell-3-output-1.png)

``` python
# First-order Gaussian filter
img = astronaut()
img_filtered = gaussian_filter(img, sigma=2.5, order=1)

plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.imshow(img.squeeze(), cmap="gray")
plt.axis("off")
plt.subplot(122)
plt.imshow(img_filtered.squeeze(), cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()
```

![](index_files/figure-commonmark/cell-4-output-1.png)

``` python
# Second-order Gaussian filter on the height dimension (y-axis)
img = astronaut()
img_filtered = gaussian_filter(img, sigma=2.5, order=[2, 0])

plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.imshow(img.squeeze(), cmap="gray")
plt.axis("off")
plt.subplot(122)
plt.imshow(img_filtered.squeeze(), cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()
```

![](index_files/figure-commonmark/cell-5-output-1.png)
8 changes: 7 additions & 1 deletion diptorch/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,10 @@
'doc_host': 'https://eigenvivek.github.io',
'git_url': 'https://github.com/eigenvivek/diptorch',
'lib_path': 'diptorch'},
'syms': {'diptorch.filter': {'diptorch.filter.foo': ('filter.html#foo', 'diptorch/filter.py')}}}
'syms': { 'diptorch.filters': { 'diptorch.filters._conv': ('filters.html#_conv', 'diptorch/filters.py'),
'diptorch.filters._gaussian_kernel_1d': ('filters.html#_gaussian_kernel_1d', 'diptorch/filters.py'),
'diptorch.filters._hessian_2d': ('filters.html#_hessian_2d', 'diptorch/filters.py'),
'diptorch.filters._hessian_3d': ('filters.html#_hessian_3d', 'diptorch/filters.py'),
'diptorch.filters.gaussian_filter': ('filters.html#gaussian_filter', 'diptorch/filters.py'),
'diptorch.filters.hessian': ('filters.html#hessian', 'diptorch/filters.py')},
'diptorch.utils': {'diptorch.utils.astronaut': ('utils.html#astronaut', 'diptorch/utils.py')}}}
8 changes: 0 additions & 8 deletions diptorch/filter.py

This file was deleted.

134 changes: 134 additions & 0 deletions diptorch/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/00_filters.ipynb.

# %% auto 0
__all__ = ['gaussian_filter', 'hessian']

# %% ../notebooks/00_filters.ipynb 3
from math import ceil, sqrt
from typing import Callable

import torch
import torch.nn.functional as F

# %% ../notebooks/00_filters.ipynb 5
def gaussian_filter(
img: torch.Tensor, # The input tensor
sigma: float, # Standard deviation for the Gaussian kernel
order: int | list = 0, # The order of the filter's derivative along each dim
mode: str = "reflect", # Padding mode for `torch.nn.functional.pad`
truncate: float = 4.0, # Number of standard deviations to sample the filter
) -> torch.Tensor:
"""
Convolves an image with a Gaussian kernel (or its derivatives).
Inspired by the API of `scipy.ndimage.gaussian_filter` and the
implementation of `diplib.Gauss`.
"""

# Specify the dimensions of the convolution to use
ndim = img.ndim - 2
if isinstance(order, int):
order = [order] * ndim
else:
assert len(order) == ndim, "Specify the Gaussian derivative order for each dim"
convfn = getattr(F, f"conv{ndim}d")

# Convolve along the rows, columns, and depth (optional)
for dim, derivative_order in enumerate(order):
img = _conv(img, convfn, sigma, derivative_order, truncate, mode, dim)
return img

# %% ../notebooks/00_filters.ipynb 6
def _gaussian_kernel_1d(
sigma: float, order: int, truncate: float, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
# Set the size of the kernel according to the sigma
radius = ceil(sigma * truncate)
x = torch.arange(-radius, radius + 1, dtype=dtype, device=device)

# Initialize the zeroth-order Gaussian kernel
var = sigma**2
g = (-x.pow(2) / (2 * var)).exp() / (sqrt(2 * torch.pi) * sigma)

# Optionally convert to a higher-order kernel
if order == 0:
return g
elif order == 1:
g1 = g * (-x / var)
g1 -= g1.mean()
g1 /= (g1 * x).sum() / -1 # Normalize the filter's impulse response to -1
return g1
elif order == 2:
g2 = g * (x.pow(2) / var - 1) / var
g2 -= g2.mean()
g2 /= (g2 * x.pow(2)).sum() / 2 # Normalize the filter's impulse response to 2
return g2
else:
raise NotImplementedError(f"Only supports order in [0, 1, 2], not {order}")


def _conv(
img: torch.Tensor,
convfn: Callable,
sigma: float,
order: int,
truncate: float,
mode: str,
dim: int,
):
# Make a 1D kernel and pad such that the image size remains the same
kernel = _gaussian_kernel_1d(sigma, order, truncate, img.dtype, img.device)
padding = len(kernel) // 2

# Specify the padding dimensions
pad = [0] * 2 * (img.ndim - 2)
for idx in range(2 * dim, 2 * dim + 2):
pad[idx] = padding
pad = pad[::-1]
x = F.pad(img, pad, mode=mode)

# Specify the dimension along which to do the convolution
view = [1] * img.ndim
view[dim + 2] *= -1

return convfn(x, weight=kernel.view(*view))

# %% ../notebooks/00_filters.ipynb 8
def hessian(img: torch.Tensor, sigma: float, **kwargs) -> torch.Tensor:
"""Compute the Hessian of a 2D or 3D image."""
if img.ndim == 4:
return _hessian_2d(img, sigma, **kwargs)
elif img.ndim == 5:
return _hessian_3d(img, sigma, **kwargs)
else:
raise ValueError(f"img can only be 2D or 3D, not {img.ndim-2}D")

# %% ../notebooks/00_filters.ipynb 9
def _hessian_2d(img: torch.Tensor, sigma: float, **kwargs):
xx = gaussian_filter(img, sigma, order=[0, 2], **kwargs).squeeze()
yy = gaussian_filter(img, sigma, order=[2, 0], **kwargs).squeeze()
xy = gaussian_filter(img, sigma, order=[1, 1], **kwargs).squeeze()
return torch.stack(
[
torch.stack([xx, xy], dim=-1),
torch.stack([xy, yy], dim=-1),
],
dim=-1,
)


def _hessian_3d(img: torch.Tensor, sigma: float, **kwargs):
xx = gaussian_filter(img, sigma, order=[0, 0, 2], **kwargs).squeeze()
yy = gaussian_filter(img, sigma, order=[0, 2, 0], **kwargs).squeeze()
zz = gaussian_filter(img, sigma, order=[2, 0, 0], **kwargs).squeeze()
xy = gaussian_filter(img, sigma, order=[0, 1, 1], **kwargs).squeeze()
xz = gaussian_filter(img, sigma, order=[1, 0, 1], **kwargs).squeeze()
yz = gaussian_filter(img, sigma, order=[1, 1, 0], **kwargs).squeeze()
return torch.stack(
[
torch.stack([xx, xy, xz], dim=-1),
torch.stack([xy, yy, yz], dim=-1),
torch.stack([xz, yz, zz], dim=-1),
],
dim=-1,
)
18 changes: 18 additions & 0 deletions diptorch/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/02_utils.ipynb.

# %% auto 0
__all__ = ['astronaut']

# %% ../notebooks/02_utils.ipynb 3
import torch
from skimage import img_as_float
from skimage.color import rgb2gray
from skimage.data import astronaut as _astronaut

# %% ../notebooks/02_utils.ipynb 4
def astronaut(dtype: torch.dtype = torch.float32):
img = _astronaut()
img = img_as_float(img)
img = rgb2gray(img)
img = torch.from_numpy(img).to(dtype)
return img[None, None]
4 changes: 3 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ channels:
- nvidia
dependencies:
- pip
- pytorch
- pytorch
- matplotlib
- scikit-image
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 9bdfde4

Please sign in to comment.