diff --git a/python/lsst/scarlet/lite/detect.py b/python/lsst/scarlet/lite/detect.py
index a5ac0aa..94047c1 100644
--- a/python/lsst/scarlet/lite/detect.py
+++ b/python/lsst/scarlet/lite/detect.py
@@ -25,7 +25,8 @@
 from typing import Sequence
 
 import numpy as np
-from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints  # type: ignore
+from dataclasses import dataclass
+from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints, Peak  # type: ignore
 
 from .bbox import Box, overlapped_slices
 from .image import Image
@@ -246,20 +247,53 @@ def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int =
     return (support.support * _coeffs).astype(images.dtype)
 
 
+def get_support(
+    image: np.ndarray,
+    sigma: float | None = None,
+    epsilon: float = 1e-1
+) -> tuple[np.ndarray, float]:
+    """Return the mask of pixels above 5 sigma and the sigma-clipped noise estimate."""
+    if sigma is None:
+        sigma = np.median(np.absolute(image - np.median(image)))
+    last_sigma = sigma
+    for _ in range(20):
+        m = np.abs(image) > 5 * sigma
+        s = ~m
+        sigma = np.std(image * s.astype(int))
+        if np.abs(sigma - last_sigma) / sigma < epsilon:
+            break
+        last_sigma = sigma
+    return m, sigma
+
+
+@dataclass
+class DetectionResult:
+    """The result of running `detect_footprints`."""
+    image_sigmas: list[float]
+    detection: np.ndarray
+    starlets: np.ndarray
+    peak_footprints: list[Footprint]
+    starlet_sigma: float
+    starlet_support: np.ndarray
+    detection_sigma: float
+    detection_support: np.ndarray
+    footprints: list[Footprint]
+    peaks: list[Peak]
+    dropped_peaks: list[Peak]
+
+
 def detect_footprints(
     images: np.ndarray,
-    variance: np.ndarray,
-    scales: int = 1,
+    detection: np.ndarray | None = None,
+    starlet_scale: int = 1,
     generation: int = 2,
     origin: tuple[int, int] | None = None,
     min_separation: float = 4,
-    min_area: int = 4,
-    peak_thresh: float = 5,
-    footprint_thresh: float = 5,
-    find_peaks: bool = True,
-    remove_high_freq: bool = True,
-    min_pixel_detect: int = 1,
-) -> list[Footprint]:
+    min_starlet_area: int = 4,
+    peak_thresh: float = 3,
+    footprint_thresh: float = 2,
+    min_footprint_area: int = 8,
+) -> DetectionResult:
     """Detect footprints in an image
 
     Parameters
@@ -267,72 +301,114 @@ def detect_footprints(
     images:
         The array of images with shape `(bands, Ny, Nx)`
         for which to calculate wavelet coefficients.
-    variance:
-        An array of variances with the same shape as `images`.
-    scales:
-        The maximum number of wavelet scales to use.
-        If `remove_high_freq` is `False`, then this argument is ignored.
+    detection:
+        An optional detection image. If `None` then one is created.
+    starlet_scale:
+        The scale of the starlet transform to use for detection.
+        All of the default configs are tuned for `starlet_scale=1`.
+        If using `starlet_scale=2` then it is recommended to use:
+        - `min_starlet_area=1`
+        - `peak_thresh=2`
     generation:
         The generation of the starlet transform to use.
-        If `remove_high_freq` is `False`, then this argument is ignored.
     origin:
         The location (y, x) of the lower corner of the image.
     min_separation:
         The minimum separation between peaks in pixels.
-    min_area:
-        The minimum area of a footprint in pixels.
+    min_starlet_area:
+        The minimum area of a footprint in starlet space in pixels.
     peak_thresh:
         The threshold for peak detection.
     footprint_thresh:
         The threshold for footprint detection.
-    find_peaks:
-        If `True`, then detect peaks in the detection image,
-        otherwise only the footprints are returned.
-    remove_high_freq:
-        If `True`, then remove high frequency wavelet coefficients
-        before detecting peaks.
-    min_pixel_detect:
-        The minimum number of bands that must be above the
-        detection threshold for a pixel to be included in a footprint.
+    min_footprint_area:
+        The minimum area of a footprint in the image in pixels.
+
+    Returns
+    -------
+    result:
+        The `DetectionResult` containing the detection image and noise
+        estimates, starlet coefficients, footprints, and peaks.
 
     """
     if origin is None:
         origin = (0, 0)
-    if remove_high_freq:
-        # Build the wavelet coefficients
-        wavelets = get_wavelets(
-            images,
-            variance,
-            scales=scales,
-            generation=generation,
-        )
-        # Remove the high frequency wavelets.
-        # This has the effect of preventing high frequency noise
-        # from interfering with the detection of peak positions.
-        wavelets[0] = 0
-        # Reconstruct the image from the remaining wavelet coefficients
-        _images = multiband_starlet_reconstruction(
-            wavelets,
-            generation=generation,
-        )
-    else:
-        _images = images
-    # Build a SNR weighted detection image
-    sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2
-    detection = np.sum(_images / sigma[:, None, None], axis=0)
-    if min_pixel_detect > 1:
-        mask = np.sum(images > 0, axis=0) >= min_pixel_detect
-        detection[~mask] = 0
+    y0, x0 = origin
+
+    sigmas: list[float] = []
+    if detection is None:
+        # Find the standard deviation of the noise in each band
+        for image in images:
+            _, sigma = get_support(image)
+            sigmas.append(sigma)
+
+        # Create the noise weighted detection image
+        detection = np.sum([image / sigma for image, sigma in zip(images, sigmas)], axis=0)
+
+    # Use the chosen scale of starlets to act as a compensated filter
+    # for detection
+    starlets = starlet_transform(detection, scales=starlet_scale + 1, generation=generation)
+
+    # Estimate the noise in the detection image
+    starlet_sigma = np.median(np.absolute(detection - np.median(detection)))
+    starlet_support = get_multiresolution_support(
+        image=detection,
+        starlets=starlets,
+        sigma=starlet_sigma,
+        sigma_scaling=3,
+    )
+
     # Detect peaks on the detection image
-    footprints = get_footprints(
-        detection,
+    peak_footprints = get_footprints(
+        starlets[starlet_scale],
         min_separation,
-        min_area,
-        peak_thresh,
-        footprint_thresh,
-        find_peaks,
-        origin[0],
-        origin[1],
+        min_starlet_area,
+        peak_thresh * starlet_support.sigma[starlet_scale],
+        footprint_thresh * starlet_support.sigma[starlet_scale],
+        True,
+        y0,
+        x0,
     )
-    return footprints
+    # Extract the peaks from the detection footprints
+    peaks = [peak for fp in peak_footprints for peak in fp.peaks]
+
+    # Create a new set of footprints from the full detection image.
+    # Detection here is performed on a boolean image, so the threshold
+    # parameters passed to `get_footprints` are simplified.
+    support, sigma = get_support(detection)
+    footprints = get_footprints(
+        detection > sigma,
+        0,
+        min_footprint_area,
+        0,
+        0,
+        False,
+        y0,
+        x0,
+    )
+
+    # Create an image of all of the footprints
+    footprint_image = footprints_to_image(footprints, Box(detection.shape, origin=origin))
+    dropped_peaks = []
+    for peak in peaks:
+        footprint_index = footprint_image.at(peak.y, peak.x) - 1
+        if footprint_index >= 0:
+            footprints[footprint_index].add_peak(peak)
+        else:
+            dropped_peaks.append(peak)
+            logger.warning(f"Peak at ({peak.y}, {peak.x}) not in footprint")
+
+    return DetectionResult(
+        image_sigmas=sigmas,
+        detection=detection,
+        starlets=starlets,
+        peak_footprints=peak_footprints,
+        starlet_sigma=starlet_sigma,
+        starlet_support=starlet_support,
+        detection_sigma=sigma,
+        detection_support=support,
+        footprints=footprints,
+        peaks=peaks,
+        dropped_peaks=dropped_peaks,
+    )
diff --git a/python/lsst/scarlet/lite/wavelet.py b/python/lsst/scarlet/lite/wavelet.py
index c1b0161..b19ac33 100644
--- a/python/lsst/scarlet/lite/wavelet.py
+++ b/python/lsst/scarlet/lite/wavelet.py
@@ -184,6 +184,7 @@ def starlet_reconstruction(
     starlets: np.ndarray,
     generation: int = 2,
     convolve2d: Callable | None = None,
+    skip_scales: list[int] | None = None,
 ) -> np.ndarray:
     """Reconstruct an image from a dictionary of starlets
 
@@ -197,6 +198,9 @@ def starlet_reconstruction(
     convolve2d:
         The filter function to use to convolve the image
         with starlets in 2D.
+    skip_scales:
+        List of scales to skip in the reconstruction.
+        This can be used to remove noise at small scales.
 
     Returns
     -------
@@ -207,11 +211,15 @@ def starlet_reconstruction(
         return np.sum(starlets, axis=0)
     if convolve2d is None:
         convolve2d = bspline_convolve
+    if skip_scales is None:
+        skip_scales = []
     scales = len(starlets) - 1
 
     c = starlets[-1]
     for i in range(1, scales + 1):
         j = scales - i
+        if j in skip_scales:
+            continue
         cj = convolve2d(c, j)
         c = cj + starlets[j]
     return c
@@ -221,6 +229,7 @@ def multiband_starlet_reconstruction(
     starlets: np.ndarray,
     generation: int = 2,
     convolve2d: Callable | None = None,
+    skip_scales: list[int] | None = None,
 ) -> np.ndarray:
     """Reconstruct a multiband image.
 
@@ -230,7 +239,12 @@ def multiband_starlet_reconstruction(
     _, bands, width, height = starlets.shape
     result = np.zeros((bands, width, height), dtype=starlets.dtype)
     for band in range(bands):
-        result[band] = starlet_reconstruction(starlets[:, band], generation=generation, convolve2d=convolve2d)
+        result[band] = starlet_reconstruction(
+            starlets[:, band],
+            generation=generation,
+            convolve2d=convolve2d,
+            skip_scales=skip_scales,
+        )
     return result
 
 
@@ -248,6 +262,7 @@ def get_multiresolution_support(
     epsilon: float = 1e-1,
     max_iter: int = 20,
     image_type: str = "ground",
+    generation: int = 2,
 ) -> MultiResolutionSupport:
     """Calculate the multi-resolution support for a dictionary of starlet coefficients.
 
@@ -295,7 +310,7 @@ def get_multiresolution_support(
         # Calculate sigma_je, the standard deviation at
         # each scale due to gaussian noise
         noise_img = np.random.normal(size=image.shape)
-        noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
+        noise_starlet = starlet_transform(noise_img, generation=generation, scales=len(starlets) - 1)
         sigma_je = np.zeros((len(noise_starlet),))
         for j, star in enumerate(noise_starlet):
             sigma_je[j] = np.std(star)
@@ -309,7 +324,7 @@ def get_multiresolution_support(
             if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon:
                 break
             last_sigma_i = sigma_i
-        sigma_j = sigma_je
+        sigma_j = sigma_je * sigma_i
     else:
         # Sigma to use for significance at each scale
         # Initially we use the input `sigma`
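A minimal usage sketch of the new `detect_footprints` interface introduced above. The import path is assumed from the file paths in this diff (`lsst.scarlet.lite.detect`), and the random `images` cube plus the printed fields are illustrative placeholders rather than part of the patch:

```python
import numpy as np

from lsst.scarlet.lite.detect import DetectionResult, detect_footprints

# Placeholder multiband cube with shape (bands, Ny, Nx); real callers
# would pass calibrated image data here.
images = np.random.normal(size=(3, 128, 128)).astype(np.float32)

# The detection image is now optional: when it is None, per-band noise
# sigmas are estimated with get_support() and a noise-weighted
# detection image is built internally.
result: DetectionResult = detect_footprints(
    images,
    starlet_scale=1,
    min_separation=4,
    peak_thresh=3,
    footprint_thresh=2,
    min_footprint_area=8,
)

# The return value is a DetectionResult rather than a bare list of
# footprints, so callers unpack the fields they need.
print(f"{len(result.footprints)} footprints, {len(result.peaks)} peaks")
print("per-band noise estimates:", result.image_sigmas)
if result.dropped_peaks:
    print(f"{len(result.dropped_peaks)} peaks fell outside every footprint")
```

Because peaks are now found on a single starlet scale while footprints are grown on the thresholded detection image, existing callers that consumed the old `list[Footprint]` return value need to switch to `result.footprints`.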
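Similarly, a small sketch of the `skip_scales` option added to `starlet_reconstruction` and threaded through `multiband_starlet_reconstruction`. The import path is again assumed from the diff, the image is a stand-in, and skipping scale 0 is only loosely analogous to the `wavelets[0] = 0` trick removed from `detect_footprints`, since the skipped scale is dropped from the reconstruction loop entirely:

```python
import numpy as np

from lsst.scarlet.lite.wavelet import starlet_reconstruction, starlet_transform

# Stand-in single-band image.
image = np.random.normal(size=(64, 64))

# Forward transform: three wavelet scales plus the smooth residual plane.
starlets = starlet_transform(image, scales=3, generation=2)

# Full reconstruction uses every scale.
full = starlet_reconstruction(starlets, generation=2)

# New in this patch: omit the finest scale (j=0) to suppress
# high-frequency noise in the reconstructed image.
denoised = starlet_reconstruction(starlets, generation=2, skip_scales=[0])

print(full.shape, denoised.shape)
```

The multiband wrapper simply forwards `skip_scales` to the per-band call, so the same pattern works on a `(scales + 1, bands, Ny, Nx)` coefficient array via `multiband_starlet_reconstruction`.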