diff --git a/src/dolphin/unwrap/_unwrap.py b/src/dolphin/unwrap/_unwrap.py index 6e1c5fa6c..e786de104 100644 --- a/src/dolphin/unwrap/_unwrap.py +++ b/src/dolphin/unwrap/_unwrap.py @@ -28,7 +28,7 @@ from ._snaphu_py import grow_conncomp_snaphu, unwrap_snaphu_py from ._tophu import multiscale_unwrap from ._unwrap_3d import unwrap_spurt -from ._utils import create_combined_mask, set_nodata_values +from ._utils import create_combined_mask, get_convex_hull_mask, set_nodata_values from ._whirlwind import unwrap_whirlwind logger = logging.getLogger("dolphin") @@ -385,6 +385,15 @@ def unwrap( sim = io.load_gdal(similarity_filename, masked=True).filled(0) coherent_pixel_mask &= sim >= sim_cutoff + # Get only the convex hull of these good pixels + # This saves time by avoiding edges that will never get unwrapped well anyway + hull = get_convex_hull_mask( + coherent_pixel_mask, buffer_pixels=preproc_options.max_radius + ) + # Skip pixels which are outside the hull + weights = coherent_pixel_mask.copy() + # We skip pixels with weight = 1 and transfer over the input data + weights[~hull] = 1 logger.info(f"Interpolating {pre_interp_ifg_filename} -> {interp_ifg_filename}") modified_ifg = interpolate( ifg=pre_interp_ifg, diff --git a/src/dolphin/unwrap/_utils.py b/src/dolphin/unwrap/_utils.py index 79df2f15f..c8e6f3bb1 100644 --- a/src/dolphin/unwrap/_utils.py +++ b/src/dolphin/unwrap/_utils.py @@ -3,6 +3,7 @@ import logging from pathlib import Path +import numpy as np import rasterio as rio from dolphin import io @@ -149,3 +150,43 @@ def _zero_from_mask( like_filename=corr_filename, ) return zeroed_ifg_file, zeroed_corr_file + + +def get_convex_hull_mask(good_mask: np.ndarray, buffer_pixels: int = 0) -> np.ndarray: + """Get a boolean image of the convex hull of `True` points in `good_mask`.""" + import numpy as np + from scipy.ndimage import binary_dilation + from scipy.spatial import ConvexHull + + hull = ConvexHull(np.stack(np.where(good_mask)).T) + # Step 1: Create a grid of all coordinates in the image. + grid_rows, grid_cols = np.indices(good_mask.shape) + + # Step 2: Reshape the grid into a list of (row, col) points. + # This (N*M, 2) array represents every pixel in the image. + all_coords = np.vstack((grid_rows.ravel(), grid_cols.ravel())).T + + # Step 3: Use the hull's equations to find points inside. + # The hull is defined by a set of planes. For a point `x` to be inside the hull, + # it must satisfy `A @ x.T + b <= 0` for all planes. + # The `hull.equations` attribute stores `A` (plane normals) and `b` (offsets). + # We can perform this check for all coordinates at once with vectorized operations. + + # `A` contains the normal vectors of the hull's planes. + A = hull.equations[:, :2] + # `b` contains the offsets of the hull's planes. + b = hull.equations[:, 2] + + # Step 4: Perform the check for all coordinates. + # `A @ all_coords.T` finds the dot product for all points with all plane normals + # We add the offset `b` and check if the result is non-positive. + # `np.all(..., axis=0)` ensures the condition is met for ALL planes for each point. + # use a small tolerance (1e-6) to account for floating-point inaccuracies. + is_inside_hull_flat = np.all(A @ all_coords.T + b[:, np.newaxis] <= 1e-6, axis=0) + + hull_mask = is_inside_hull_flat.reshape(good_mask.shape) + if buffer_pixels: + hull_mask = binary_dilation( + hull_mask, structure=np.ones((3, 3)), iterations=buffer_pixels + ) + return hull_mask