Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 128 additions & 61 deletions python/lsst/scarlet/lite/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from typing import Sequence

import numpy as np
from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints # type: ignore
from dataclasses import dataclass
from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints, Peak # type: ignore

from .bbox import Box, overlapped_slices
from .image import Image
from .utils import continue_class
from .wavelet import (

Check failure on line 34 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

F401

'.wavelet.multiband_starlet_reconstruction' imported but unused

Check failure on line 34 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

F401

'.wavelet.multiband_starlet_reconstruction' imported but unused
get_multiresolution_support,
get_starlet_scales,
multiband_starlet_reconstruction,
Expand Down Expand Up @@ -246,93 +247,159 @@
return (support.support * _coeffs).astype(images.dtype)


def get_support(
image: np.ndarray,
sigma: float | None = None,
epsilon: float = 1e-1
):
if sigma is None:
sigma = np.median(np.absolute(image - np.median(image)))
last_sigma = sigma
for _ in range(20):
m = np.abs(image) > 5 * sigma
s = ~m
sigma = np.std(image*s.astype(int))

Check failure on line 261 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E226

missing whitespace around arithmetic operator

Check failure on line 261 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

E226

missing whitespace around arithmetic operator
if np.abs(sigma - last_sigma)/sigma < epsilon:

Check failure on line 262 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E226

missing whitespace around arithmetic operator

Check failure on line 262 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

E226

missing whitespace around arithmetic operator
break
return m, sigma


@dataclass
class DetectionResult:
image_sigmas: list[float]
detection: np.ndarray
starlets: np.ndarray
peak_footprints: list[Footprint]
starlet_sigma: float
starlet_support: np.ndarray
detection_sigma: float
detection_support: np.ndarray
footprints: list[Footprint]
peaks: list[Peak]
dropped_peaks: list[Peak]


def detect_footprints(
images: np.ndarray,
variance: np.ndarray,
scales: int = 1,
detection: np.ndarray | None = None,
starlet_scale: int = 1,
generation: int = 2,
origin: tuple[int, int] | None = None,
min_separation: float = 4,
min_area: int = 4,
peak_thresh: float = 5,
footprint_thresh: float = 5,
find_peaks: bool = True,
remove_high_freq: bool = True,
min_pixel_detect: int = 1,
) -> list[Footprint]:
min_starlet_area: int = 4,
peak_thresh: float = 3,
footprint_thresh: float = 2,
min_footprint_area: int = 8,
) -> DetectionResult:
"""Detect footprints in an image

Parameters
----------
images:
The array of images with shape `(bands, Ny, Nx)` for which to
calculate wavelet coefficients.
variance:
An array of variances with the same shape as `images`.
scales:
The maximum number of wavelet scales to use.
If `remove_high_freq` is `False`, then this argument is ignored.
detection:
An optional detection image. If `None` then one is created.
starlet_scale:
The scale of the starlet transform to use for detection.
All of the default configs are tuned for `starlet_scale=1`.
If using `starlet_scale=2` then it is recommended to use:
- `min_starlet_area=1`
- `peak_thresh=2`
generation:
The generation of the starlet transform to use.
If `remove_high_freq` is `False`, then this argument is ignored.
origin:
The location (y, x) of the lower corner of the image.
min_separation:
The minimum separation between peaks in pixels.
min_area:
The minimum area of a footprint in pixels.
min_starlet_area:
The minimum area of a footprint in starlet space in pixels.
peak_thresh:
The threshold for peak detection.
footprint_thresh:
The threshold for footprint detection.
find_peaks:
If `True`, then detect peaks in the detection image,
otherwise only the footprints are returned.
remove_high_freq:
If `True`, then remove high frequency wavelet coefficients
before detecting peaks.
min_pixel_detect:
The minimum number of bands that must be above the
detection threshold for a pixel to be included in a footprint.
min_footprint_area:
The minimum area of a footprint in the image in pixels.
"""

if origin is None:
origin = (0, 0)
if remove_high_freq:
# Build the wavelet coefficients
wavelets = get_wavelets(
images,
variance,
scales=scales,
generation=generation,
)
# Remove the high frequency wavelets.
# This has the effect of preventing high frequency noise
# from interfering with the detection of peak positions.
wavelets[0] = 0
# Reconstruct the image from the remaining wavelet coefficients
_images = multiband_starlet_reconstruction(
wavelets,
generation=generation,
)
else:
_images = images
# Build a SNR weighted detection image
sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2
detection = np.sum(_images / sigma[:, None, None], axis=0)
if min_pixel_detect > 1:
mask = np.sum(images > 0, axis=0) >= min_pixel_detect
detection[~mask] = 0
y0, x0 = origin

if detection is None:
# Find the standard deviation of the noise in each band
sigmas = []
for image in images:
support, sigma = get_support(image)
sigmas.append(sigma)

# Create the variance weighted detection image
detection = np.sum([image/sigma for image, sigma in zip(images, sigmas)], axis=0)

Check failure on line 337 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E226

missing whitespace around arithmetic operator

Check failure on line 337 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

E226

missing whitespace around arithmetic operator

# Use the chosen scale of starlets to act as a compensated filter
# for detection
starlets = starlet_transform(detection, scales=starlet_scale+1, generation=generation)

Check failure on line 341 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E226

missing whitespace around arithmetic operator

Check failure on line 341 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

E226

missing whitespace around arithmetic operator

# Estimate the noise in the detection image
starlet_sigma = np.median(np.absolute(detection - np.median(detection)))
starlet_support = get_multiresolution_support(
image=detection,
starlets=starlets,
sigma=starlet_sigma,
sigma_scaling=3,
)

# Detect peaks on the detection image
footprints = get_footprints(
detection,
peak_footprints = get_footprints(
starlets[starlet_scale],
min_separation,
min_area,
peak_thresh,
footprint_thresh,
find_peaks,
origin[0],
origin[1],
min_starlet_area,
peak_thresh*starlet_support.sigma[starlet_scale],

Check failure on line 357 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E226

missing whitespace around arithmetic operator

Check failure on line 357 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

E226

missing whitespace around arithmetic operator
footprint_thresh*starlet_support.sigma[starlet_scale],

Check failure on line 358 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E226

missing whitespace around arithmetic operator

Check failure on line 358 in python/lsst/scarlet/lite/detect.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint_with_flake8

E226

missing whitespace around arithmetic operator
True,
y0,
x0,
)

return footprints
# Extract the peaks from the detection footprints
peaks = [peak for fp in peak_footprints for peak in fp.peaks]

# Create a new set of footprints on the images.
# Detection is made on the boolean image, so the parameters
# that we pass to it are simplified.
support, sigma = get_support(detection)
footprints = get_footprints(
detection > sigma,
0,
min_footprint_area,
0,
0,
False,
y0,
x0,
)

# Create an image of all of the footprints
footprint_image = footprints_to_image(footprints, Box(detection.shape, origin=origin))
dropped_peaks = []
for peak in peaks:
footprint_index = footprint_image.at(peak.y, peak.x) - 1
if footprint_index >= 0:
footprints[footprint_index].add_peak(peak)
else:
dropped_peaks.append(peak)
logger.warning(f"Peak at ({peak.y}, {peak.x}) not in footprint")

return DetectionResult(
image_sigmas=sigmas,
detection=detection,
starlets=starlets,
peak_footprints=peak_footprints,
starlet_sigma=starlet_sigma,
starlet_support=starlet_support,
detection_sigma=sigma,
detection_support=support,
footprints=footprints,
peaks=peaks,
dropped_peaks=dropped_peaks,
)
21 changes: 18 additions & 3 deletions python/lsst/scarlet/lite/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def starlet_reconstruction(
starlets: np.ndarray,
generation: int = 2,
convolve2d: Callable | None = None,
skip_scales: list[int] | None = None,
) -> np.ndarray:
"""Reconstruct an image from a dictionary of starlets

Expand All @@ -197,6 +198,9 @@ def starlet_reconstruction(
convolve2d:
The filter function to use to convolve the image
with starlets in 2D.
skip_scales:
List of scales to skip in the reconstruction.
This can be used to remove noise at small scales.

Returns
-------
Expand All @@ -207,11 +211,15 @@ def starlet_reconstruction(
return np.sum(starlets, axis=0)
if convolve2d is None:
convolve2d = bspline_convolve
if skip_scales is None:
skip_scales = []
scales = len(starlets) - 1

c = starlets[-1]
for i in range(1, scales + 1):
j = scales - i
if j in skip_scales:
continue
cj = convolve2d(c, j)
c = cj + starlets[j]
return c
Expand All @@ -221,6 +229,7 @@ def multiband_starlet_reconstruction(
starlets: np.ndarray,
generation: int = 2,
convolve2d: Callable | None = None,
skip_scales: list[int] | None = None,
) -> np.ndarray:
"""Reconstruct a multiband image.

Expand All @@ -230,7 +239,12 @@ def multiband_starlet_reconstruction(
_, bands, width, height = starlets.shape
result = np.zeros((bands, width, height), dtype=starlets.dtype)
for band in range(bands):
result[band] = starlet_reconstruction(starlets[:, band], generation=generation, convolve2d=convolve2d)
result[band] = starlet_reconstruction(
starlets[:, band],
generation=generation,
convolve2d=convolve2d,
skip_scales=skip_scales,
)
return result


Expand All @@ -248,6 +262,7 @@ def get_multiresolution_support(
epsilon: float = 1e-1,
max_iter: int = 20,
image_type: str = "ground",
generation: int = 2,
) -> MultiResolutionSupport:
"""Calculate the multi-resolution support for a
dictionary of starlet coefficients.
Expand Down Expand Up @@ -295,7 +310,7 @@ def get_multiresolution_support(
# Calculate sigma_je, the standard deviation at
# each scale due to gaussian noise
noise_img = np.random.normal(size=image.shape)
noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
noise_starlet = starlet_transform(noise_img, generation=generation, scales=len(starlets) - 1)
sigma_je = np.zeros((len(noise_starlet),))
for j, star in enumerate(noise_starlet):
sigma_je[j] = np.std(star)
Expand All @@ -309,7 +324,7 @@ def get_multiresolution_support(
if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon:
break
last_sigma_i = sigma_i
sigma_j = sigma_je
sigma_j = sigma_je * sigma_i
else:
# Sigma to use for significance at each scale
# Initially we use the input `sigma`
Expand Down
Loading