From 22aad4aa9a004432816e52794c6fbdb9fee57890 Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Mon, 9 Oct 2023 16:58:01 -0500 Subject: [PATCH] Improve tests for subsampling in NuthKaab --- tests/test_coreg/test_base.py | 5 +++-- xdem/coreg/base.py | 2 +- xdem/examples.py | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_coreg/test_base.py b/tests/test_coreg/test_base.py index b5c1427d..2e0bfe73 100644 --- a/tests/test_coreg/test_base.py +++ b/tests/test_coreg/test_base.py @@ -134,10 +134,9 @@ def test_get_subsample_on_valid_mask(self, subsample: float | int) -> None: subsample_val = subsample assert np.count_nonzero(subsample_mask) == min(subsample_val, np.count_nonzero(valid_mask)) - # TODO: Activate NuthKaab once subsampling there is made consistent all_coregs = [ coreg.VerticalShift, - # coreg.NuthKaab, + coreg.NuthKaab, coreg.ICP, coreg.Deramp, coreg.TerrainBias, @@ -156,6 +155,8 @@ def test_subsample(self, coreg: Callable) -> None: # type: ignore # But can be overridden during fit coreg_full.fit(**self.fit_params, subsample=10000, random_state=42) assert coreg_full._meta["subsample"] == 10000 + # Check that the random state is properly set when subsampling explicitly or implicitly + assert coreg_full._meta["random_state"] == 42 # Test subsampled vertical shift correction coreg_sub = coreg(subsample=0.1) diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index e750ea66..24dad933 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -822,7 +822,7 @@ def fit( # In any case, override! self._meta["subsample"] = subsample - # Save random_state is a subsample is used + # Save random_state if a subsample is used if self._meta["subsample"] != 1: self._meta["random_state"] = random_state diff --git a/xdem/examples.py b/xdem/examples.py index 01769e19..cf56c6ed 100644 --- a/xdem/examples.py +++ b/xdem/examples.py @@ -103,6 +103,10 @@ def process_coregistered_examples(name: str, overwrite: bool = False) -> None: nuth_kaab = xdem.coreg.NuthKaab() nuth_kaab.fit(reference_raster, to_be_aligned_raster, inlier_mask=inlier_mask, random_state=42) + + # Check that random state is respected + assert nuth_kaab._meta["random_state"] == 42 + aligned_raster = nuth_kaab.apply(to_be_aligned_raster, resample=True) diff = reference_raster - aligned_raster