diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6979c53f..9f4e8857 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -125,11 +125,11 @@ repos: pass_filenames: false additional_dependencies: [tomli, pyyaml] - # Add license header to the source files - - repo: local - hooks: - - id: add-license-header - name: Add License Header - entry: python .github/scripts/apply_license_header.py - language: python - files: \.py$ +# # Add license header to the source files +# - repo: local +# hooks: +# - id: add-license-header +# name: Add License Header +# entry: python .github/scripts/apply_license_header.py +# language: python +# files: \.py$ diff --git a/tests/test_coreg/test_base.py b/tests/test_coreg/test_base.py index a6d3d976..0c1a7414 100644 --- a/tests/test_coreg/test_base.py +++ b/tests/test_coreg/test_base.py @@ -21,6 +21,7 @@ import xdem from xdem import coreg, examples, misc, spatialstats from xdem._typing import NDArrayf +from xdem.coreg import BlockwiseCoreg from xdem.coreg.base import Coreg, apply_matrix, dict_key_to_str @@ -924,11 +925,8 @@ def test_blockwise_coreg_large_gaps(self) -> None: stats = blockwise.stats() - # We expect holes in the blockwise coregistration, so there should not be 64 "successful" blocks. - assert stats.shape[0] < 64 - - # Statistics are only calculated on finite values, so all of these should be finite as well. - assert np.all(np.isfinite(stats)) + # We expect holes in the blockwise coregistration, but not in stats due to nan padding for failing chunks + assert stats.shape[0] == 64 # Copy the TBA DEM and set a square portion to nodata tba = self.tba.copy() @@ -938,7 +936,7 @@ def test_blockwise_coreg_large_gaps(self) -> None: blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 8, warn_failures=False) - # Align the DEM and apply the blockwise to a zero-array (to get the zshift) + # Align the DEM and apply blockwise to a zero-array (to get the z_shift) aligned = blockwise.fit(self.ref, tba).apply(tba) zshift, _ = blockwise.apply(np.zeros_like(tba.data), transform=tba.transform, crs=tba.crs) @@ -952,6 +950,39 @@ def test_blockwise_coreg_large_gaps(self) -> None: assert abs(np.nanmedian(ddem_pre)) > abs(np.nanmedian(ddem_post)) # assert np.nanstd(ddem_pre) > np.nanstd(ddem_post) + def test_failed_chunks_return_nan(self) -> None: + blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=4) + blockwise.fit(**self.fit_params) + # Missing chunk 1 to simulate failure + blockwise._meta["step_meta"] = [meta for meta in blockwise._meta["step_meta"] if meta.get("i") != 1] + + result_df = blockwise.stats() + + # Check that chunk 1 (index 1) has NaN values for the statistics + assert np.isnan(result_df.loc[1, "inlier_count"]) + assert np.isnan(result_df.loc[1, "nmad"]) + assert np.isnan(result_df.loc[1, "median"]) + assert isinstance(result_df.loc[1, "center_x"], float) + assert isinstance(result_df.loc[1, "center_y"], float) + assert np.isnan(result_df.loc[1, "center_z"]) + assert np.isnan(result_df.loc[1, "x_off"]) + assert np.isnan(result_df.loc[1, "y_off"]) + assert np.isnan(result_df.loc[1, "z_off"]) + + def test_successful_chunks_return_values(self) -> None: + blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=2) + blockwise.fit(**self.fit_params) + result_df = blockwise.stats() + + # Check that the correct statistics are returned for successful chunks + assert result_df.loc[0, "inlier_count"] == blockwise._meta["step_meta"][0]["inlier_count"] + assert result_df.loc[0, "nmad"] == blockwise._meta["step_meta"][0]["nmad"] + assert result_df.loc[0, "median"] == blockwise._meta["step_meta"][0]["median"] + + assert result_df.loc[1, "inlier_count"] == blockwise._meta["step_meta"][1]["inlier_count"] + assert result_df.loc[1, "nmad"] == blockwise._meta["step_meta"][1]["nmad"] + assert result_df.loc[1, "median"] == blockwise._meta["step_meta"][1]["median"] + class TestAffineManipulation: diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index 5222ddc3..a37d81f3 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -3036,6 +3036,7 @@ def __init__( super().__init__() self._meta: CoregDict = {"step_meta": []} + self._groups: NDArrayf = np.array([]) def fit( self: CoregType, @@ -3091,9 +3092,9 @@ def fit( else: mask = inlier_mask - groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape) + self._groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape) - indices = np.unique(groups) + indices = np.unique(self._groups) progress_bar = tqdm( total=indices.size, desc="Processing chunks", disable=logging.getLogger().getEffectiveLevel() > logging.INFO @@ -3108,7 +3109,7 @@ def process(i: int) -> dict[str, Any] | BaseException | None: * If it fails: The associated exception. * If the block is empty: None """ - group_mask = groups == i + group_mask = self._groups == i # Find the corresponding slice of the inlier_mask to subset the data rows, cols = np.where(group_mask) @@ -3272,24 +3273,44 @@ def to_points(self) -> NDArrayf: if len(self._meta["step_meta"]) == 0: raise AssertionError("No coreg results exist. Has '.fit()' been called?") points = np.empty(shape=(0, 3, 2)) - for meta in self._meta["step_meta"]: - self._restore_metadata(meta) - # x_coord, y_coord = rio.transform.xy(meta["transform"], meta["representative_row"], - # meta["representative_col"]) - x_coord, y_coord = meta["representative_x"], meta["representative_y"] + for i in range(self.subdivision): + # Try to restore the metadata for this chunk (if it succeeded) + chunk_meta = next((meta for meta in self._meta["step_meta"] if meta["i"] == i), None) - old_pos_arr = np.reshape([x_coord, y_coord, meta["representative_val"]], (1, 3)) + if chunk_meta is not None: + # Successful chunk: Retrieve the representative X, Y, Z coordinates + self._restore_metadata(chunk_meta) + x_coord, y_coord = chunk_meta["representative_x"], chunk_meta["representative_y"] + repr_val = chunk_meta["representative_val"] + else: + # Failed chunk: Calculate the approximate center using the group's bounds + rows, cols = np.where(self._groups == i) + center_row = (rows.min() + rows.max()) // 2 + center_col = (cols.min() + cols.max()) // 2 + + transform = self._meta["step_meta"][0]["transform"] # Assuming all chunks share a transform + x_coord, y_coord = rio.transform.xy(transform, center_row, center_col) + repr_val = np.nan # No valid Z value for failed chunks + + # Old position based on the calculated or retrieved coordinates + old_pos_arr = np.reshape([x_coord, y_coord, repr_val], (1, 3)) old_position = gpd.GeoDataFrame( geometry=gpd.points_from_xy(x=old_pos_arr[:, 0], y=old_pos_arr[:, 1], crs=None), data={"z": old_pos_arr[:, 2]}, ) - new_position = self.procstep.apply(old_position) - new_pos_arr = np.reshape( - [new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3) - ) + if chunk_meta is not None: + # Successful chunk: Apply the transformation + new_position = self.procstep.apply(old_position) + new_pos_arr = np.reshape( + [new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3) + ) + else: + # Failed chunk: Keep the new position the same as the old position (no transformation) + new_pos_arr = old_pos_arr.copy() + # Append the result points = np.append(points, np.dstack((old_pos_arr, new_pos_arr)), axis=0) return points @@ -3307,6 +3328,7 @@ def stats(self) -> pd.DataFrame: :raises ValueError: If no coregistration results exist yet. :returns: A dataframe of statistics for each chunk. + If a chunk fails (not present in `chunk_meta`), the statistics will be returned as `NaN`. """ points = self.to_points() @@ -3315,20 +3337,34 @@ def stats(self) -> pd.DataFrame: statistics: list[dict[str, Any]] = [] for i in range(points.shape[0]): if i not in chunk_meta: - continue - statistics.append( - { - "center_x": points[i, 0, 0], - "center_y": points[i, 1, 0], - "center_z": points[i, 2, 0], - "x_off": points[i, 0, 1] - points[i, 0, 0], - "y_off": points[i, 1, 1] - points[i, 1, 0], - "z_off": points[i, 2, 1] - points[i, 2, 0], - "inlier_count": chunk_meta[i]["inlier_count"], - "nmad": chunk_meta[i]["nmad"], - "median": chunk_meta[i]["median"], - } - ) + # For missing chunks, return NaN for all stats + statistics.append( + { + "center_x": points[i, 0, 0], + "center_y": points[i, 1, 0], + "center_z": points[i, 2, 0], + "x_off": np.nan, + "y_off": np.nan, + "z_off": np.nan, + "inlier_count": np.nan, + "nmad": np.nan, + "median": np.nan, + } + ) + else: + statistics.append( + { + "center_x": points[i, 0, 0], + "center_y": points[i, 1, 0], + "center_z": points[i, 2, 0], + "x_off": points[i, 0, 1] - points[i, 0, 0], + "y_off": points[i, 1, 1] - points[i, 1, 0], + "z_off": points[i, 2, 1] - points[i, 2, 0], + "inlier_count": chunk_meta[i]["inlier_count"], + "nmad": chunk_meta[i]["nmad"], + "median": chunk_meta[i]["median"], + } + ) stats_df = pd.DataFrame(statistics) stats_df.index.name = "chunk" @@ -3364,6 +3400,11 @@ def _apply_rst( raise NotImplementedError("Option `resample=False` not supported for coreg method BlockwiseCoreg.") points = self.to_points() + # Check for NaN values across both the old and new positions for each point + mask = ~np.isnan(points).any(axis=(1, 2)) + + # Filter out points where there are no NaN values + points = points[mask] bounds = _bounds(transform=transform, shape=elev.shape) resolution = _res(transform) @@ -3406,6 +3447,12 @@ def _apply_pts( """Apply the scaling model to a set of points.""" points = self.to_points() + # Check for NaN values across both the old and new positions for each point + mask = ~np.isnan(points).any(axis=(1, 2)) + + # Filter out points where there are no NaN values + points = points[mask] + new_coords = np.array([elev.geometry.x.values, elev.geometry.y.values, elev["z"].values]).T for dim in range(0, 3): @@ -3512,7 +3559,7 @@ def warp_dem( order = {"nearest": 0, "linear": 1, "cubic": 3} with warnings.catch_warnings(): - # An skimage warning that will hopefully be fixed soon. (2021-06-08) + # A skimage warning that will hopefully be fixed soon. (2021-06-08) warnings.filterwarnings("ignore", message="Passing `np.nan` to mean no clipping in np.clip") warped = skimage.transform.warp( image=np.where(dem_mask, np.nan, dem_arr),