-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement Gaussian filtering * Add test image loader * Add Gaussian filtering tutorial * Fix tests
- Loading branch information
1 parent
b471c10
commit 9bdfde4
Showing
13 changed files
with
621 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/00_filters.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['gaussian_filter', 'hessian'] | ||
|
||
# %% ../notebooks/00_filters.ipynb 3 | ||
from math import ceil, sqrt | ||
from typing import Callable | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
# %% ../notebooks/00_filters.ipynb 5 | ||
def gaussian_filter( | ||
img: torch.Tensor, # The input tensor | ||
sigma: float, # Standard deviation for the Gaussian kernel | ||
order: int | list = 0, # The order of the filter's derivative along each dim | ||
mode: str = "reflect", # Padding mode for `torch.nn.functional.pad` | ||
truncate: float = 4.0, # Number of standard deviations to sample the filter | ||
) -> torch.Tensor: | ||
""" | ||
Convolves an image with a Gaussian kernel (or its derivatives). | ||
Inspired by the API of `scipy.ndimage.gaussian_filter` and the | ||
implementation of `diplib.Gauss`. | ||
""" | ||
|
||
# Specify the dimensions of the convolution to use | ||
ndim = img.ndim - 2 | ||
if isinstance(order, int): | ||
order = [order] * ndim | ||
else: | ||
assert len(order) == ndim, "Specify the Gaussian derivative order for each dim" | ||
convfn = getattr(F, f"conv{ndim}d") | ||
|
||
# Convolve along the rows, columns, and depth (optional) | ||
for dim, derivative_order in enumerate(order): | ||
img = _conv(img, convfn, sigma, derivative_order, truncate, mode, dim) | ||
return img | ||
|
||
# %% ../notebooks/00_filters.ipynb 6 | ||
def _gaussian_kernel_1d( | ||
sigma: float, order: int, truncate: float, dtype: torch.dtype, device: torch.device | ||
) -> torch.Tensor: | ||
# Set the size of the kernel according to the sigma | ||
radius = ceil(sigma * truncate) | ||
x = torch.arange(-radius, radius + 1, dtype=dtype, device=device) | ||
|
||
# Initialize the zeroth-order Gaussian kernel | ||
var = sigma**2 | ||
g = (-x.pow(2) / (2 * var)).exp() / (sqrt(2 * torch.pi) * sigma) | ||
|
||
# Optionally convert to a higher-order kernel | ||
if order == 0: | ||
return g | ||
elif order == 1: | ||
g1 = g * (-x / var) | ||
g1 -= g1.mean() | ||
g1 /= (g1 * x).sum() / -1 # Normalize the filter's impulse response to -1 | ||
return g1 | ||
elif order == 2: | ||
g2 = g * (x.pow(2) / var - 1) / var | ||
g2 -= g2.mean() | ||
g2 /= (g2 * x.pow(2)).sum() / 2 # Normalize the filter's impulse response to 2 | ||
return g2 | ||
else: | ||
raise NotImplementedError(f"Only supports order in [0, 1, 2], not {order}") | ||
|
||
|
||
def _conv( | ||
img: torch.Tensor, | ||
convfn: Callable, | ||
sigma: float, | ||
order: int, | ||
truncate: float, | ||
mode: str, | ||
dim: int, | ||
): | ||
# Make a 1D kernel and pad such that the image size remains the same | ||
kernel = _gaussian_kernel_1d(sigma, order, truncate, img.dtype, img.device) | ||
padding = len(kernel) // 2 | ||
|
||
# Specify the padding dimensions | ||
pad = [0] * 2 * (img.ndim - 2) | ||
for idx in range(2 * dim, 2 * dim + 2): | ||
pad[idx] = padding | ||
pad = pad[::-1] | ||
x = F.pad(img, pad, mode=mode) | ||
|
||
# Specify the dimension along which to do the convolution | ||
view = [1] * img.ndim | ||
view[dim + 2] *= -1 | ||
|
||
return convfn(x, weight=kernel.view(*view)) | ||
|
||
# %% ../notebooks/00_filters.ipynb 8 | ||
def hessian(img: torch.Tensor, sigma: float, **kwargs) -> torch.Tensor: | ||
"""Compute the Hessian of a 2D or 3D image.""" | ||
if img.ndim == 4: | ||
return _hessian_2d(img, sigma, **kwargs) | ||
elif img.ndim == 5: | ||
return _hessian_3d(img, sigma, **kwargs) | ||
else: | ||
raise ValueError(f"img can only be 2D or 3D, not {img.ndim-2}D") | ||
|
||
# %% ../notebooks/00_filters.ipynb 9 | ||
def _hessian_2d(img: torch.Tensor, sigma: float, **kwargs): | ||
xx = gaussian_filter(img, sigma, order=[0, 2], **kwargs).squeeze() | ||
yy = gaussian_filter(img, sigma, order=[2, 0], **kwargs).squeeze() | ||
xy = gaussian_filter(img, sigma, order=[1, 1], **kwargs).squeeze() | ||
return torch.stack( | ||
[ | ||
torch.stack([xx, xy], dim=-1), | ||
torch.stack([xy, yy], dim=-1), | ||
], | ||
dim=-1, | ||
) | ||
|
||
|
||
def _hessian_3d(img: torch.Tensor, sigma: float, **kwargs): | ||
xx = gaussian_filter(img, sigma, order=[0, 0, 2], **kwargs).squeeze() | ||
yy = gaussian_filter(img, sigma, order=[0, 2, 0], **kwargs).squeeze() | ||
zz = gaussian_filter(img, sigma, order=[2, 0, 0], **kwargs).squeeze() | ||
xy = gaussian_filter(img, sigma, order=[0, 1, 1], **kwargs).squeeze() | ||
xz = gaussian_filter(img, sigma, order=[1, 0, 1], **kwargs).squeeze() | ||
yz = gaussian_filter(img, sigma, order=[1, 1, 0], **kwargs).squeeze() | ||
return torch.stack( | ||
[ | ||
torch.stack([xx, xy, xz], dim=-1), | ||
torch.stack([xy, yy, yz], dim=-1), | ||
torch.stack([xz, yz, zz], dim=-1), | ||
], | ||
dim=-1, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/02_utils.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['astronaut'] | ||
|
||
# %% ../notebooks/02_utils.ipynb 3 | ||
import torch | ||
from skimage import img_as_float | ||
from skimage.color import rgb2gray | ||
from skimage.data import astronaut as _astronaut | ||
|
||
# %% ../notebooks/02_utils.ipynb 4 | ||
def astronaut(dtype: torch.dtype = torch.float32): | ||
img = _astronaut() | ||
img = img_as_float(img) | ||
img = rgb2gray(img) | ||
img = torch.from_numpy(img).to(dtype) | ||
return img[None, None] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,6 @@ channels: | |
- nvidia | ||
dependencies: | ||
- pip | ||
- pytorch | ||
- pytorch | ||
- matplotlib | ||
- scikit-image |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.