From eb7f398ba5125cd0d1f02fcc3c5463989146fa5c Mon Sep 17 00:00:00 2001 From: DEGERICJ Date: Tue, 15 Oct 2024 13:25:35 +0200 Subject: [PATCH] get rid of conf_threshold parameter as it may introduce "undecided" pixels --- .../worldcereal_v1_demo_custom_cropland.ipynb | 2 -- .../worldcereal_v1_demo_custom_croptype.ipynb | 2 -- ...dcereal_v1_demo_default_cropland_EXTENDED.ipynb | 2 -- src/worldcereal/openeo/postprocess.py | 14 +++----------- src/worldcereal/parameters.py | 7 ------- tests/worldcerealtests/test_postprocessing.py | 9 --------- 6 files changed, 3 insertions(+), 33 deletions(-) diff --git a/notebooks/worldcereal_v1_demo_custom_cropland.ipynb b/notebooks/worldcereal_v1_demo_custom_cropland.ipynb index 76bb973e..6fa4aabf 100644 --- a/notebooks/worldcereal_v1_demo_custom_cropland.ipynb +++ b/notebooks/worldcereal_v1_demo_custom_cropland.ipynb @@ -490,7 +490,6 @@ "postprocess_method = \"majority_vote\"\n", "# Additiona parameters for the majority vote method:\n", "kernel_size = 3 # default = 5\n", - "conf_threshold = 60 # default = 30\n", "# Do you want to save the intermediate results (before applying the postprocessing)\n", "save_intermediate = True #default is False\n", "# Do you want to save all class probabilities in the final product? (default is False)\n", @@ -499,7 +498,6 @@ "postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n", " method=postprocess_method,\n", " kernel_size=kernel_size,\n", - " conf_threshold=conf_threshold,\n", " save_intermediate=save_intermediate,\n", " keep_class_probs=keep_class_probs)\n", "\n", diff --git a/notebooks/worldcereal_v1_demo_custom_croptype.ipynb b/notebooks/worldcereal_v1_demo_custom_croptype.ipynb index 32071b05..5aa4730b 100644 --- a/notebooks/worldcereal_v1_demo_custom_croptype.ipynb +++ b/notebooks/worldcereal_v1_demo_custom_croptype.ipynb @@ -508,7 +508,6 @@ "postprocess_method = \"majority_vote\"\n", "# Additiona parameters for the majority vote method:\n", "kernel_size = 5 # default = 5\n", - "conf_threshold = 30 # default = 30\n", "# Do you want to save the intermediate results (before applying the postprocessing)\n", "save_intermediate = True #default is False\n", "# Do you want to save all class probabilities in the final product? (default is False)\n", @@ -517,7 +516,6 @@ "postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n", " method=postprocess_method,\n", " kernel_size=kernel_size,\n", - " conf_threshold=conf_threshold,\n", " save_intermediate=save_intermediate,\n", " keep_class_probs=keep_class_probs)\n", "\n", diff --git a/notebooks/worldcereal_v1_demo_default_cropland_EXTENDED.ipynb b/notebooks/worldcereal_v1_demo_default_cropland_EXTENDED.ipynb index 59a41b02..ca5b377d 100644 --- a/notebooks/worldcereal_v1_demo_default_cropland_EXTENDED.ipynb +++ b/notebooks/worldcereal_v1_demo_default_cropland_EXTENDED.ipynb @@ -222,7 +222,6 @@ "postprocess_method = \"majority_vote\"\n", "# Additiona parameters for the majority vote method:\n", "kernel_size = 3 # default = 5\n", - "conf_threshold = 60 # default = 30\n", "# Do you want to save the intermediate results (before applying the postprocessing)\n", "save_intermediate = True #default is False\n", "# Do you want to save all class probabilities in the final product? (default is False)\n", @@ -231,7 +230,6 @@ "postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n", " method=postprocess_method,\n", " kernel_size=kernel_size,\n", - " conf_threshold=conf_threshold,\n", " save_intermediate=save_intermediate,\n", " keep_class_probs=keep_class_probs)\n", "\n", diff --git a/src/worldcereal/openeo/postprocess.py b/src/worldcereal/openeo/postprocess.py index 33d85575..b9411142 100644 --- a/src/worldcereal/openeo/postprocess.py +++ b/src/worldcereal/openeo/postprocess.py @@ -33,11 +33,10 @@ def majority_vote( base_labels: xr.DataArray, max_probabilities: xr.DataArray, kernel_size: int, - conf_threshold: int, ) -> xr.DataArray: """Majority vote is performed using a sliding local kernel. - For each pixel, the voting of a final class is done from - neighbours values weighted with the confidence threshold. + For each pixel, the voting of a final class is done by counting + neighbours values. Pixels that have one of the specified excluded values are excluded in the voting process and are unchanged. @@ -55,8 +54,6 @@ def majority_vote( The original probabilities of the winning class (ranging between 0 and 100). kernel_size : int The size of the kernel used for the neighbour around the pixel. - conf_threshold : int - Pixels under this confidence threshold do not count into the voting process. Returns ------- @@ -93,9 +90,6 @@ def majority_vote( # Take the binary mask of the interest class, and multiply by the probabilities class_mask = ((prediction == cls_value) * probability).astype(np.uint16) - # Sets to 0 the class scores where the threshold is lower - class_mask[probability <= conf_threshold] = 0 - # Set to 0 the class scores where the label is excluded for excluded_value in cls.EXCLUDED_VALUES: class_mask[prediction == excluded_value] = 0 @@ -156,7 +150,7 @@ def majority_vote( # Setting excluded values back to their original values for excluded_value in cls.EXCLUDED_VALUES: aggregated_predictions[prediction == excluded_value] = excluded_value - aggregated_probabilities[prediction == excluded_value] = cls.NODATA + aggregated_probabilities[prediction == excluded_value] = excluded_value return xr.DataArray( np.stack((aggregated_predictions, aggregated_probabilities)), @@ -286,13 +280,11 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: elif self._parameters.get("method") == "majority_vote": kernel_size = self._parameters.get("kernel_size") - conf_threshold = self._parameters.get("conf_threshold") new_labels = PostProcessor.majority_vote( inarr.sel(bands="classification"), inarr.sel(bands="probability"), kernel_size=kernel_size, - conf_threshold=conf_threshold, ) # Append the per-class probabalities if required diff --git a/src/worldcereal/parameters.py b/src/worldcereal/parameters.py index 7fbc1606..69a1fa44 100644 --- a/src/worldcereal/parameters.py +++ b/src/worldcereal/parameters.py @@ -168,8 +168,6 @@ class PostprocessParameters(BaseModel): The method to use for postprocessing. Must be one of ["smooth_probabilities", "majority_vote"] kernel_size: int (default=5) Used for majority vote postprocessing. Must be smaller than 25. - conf_threshold: int (default=30) - Used for majority vote postprocessing. Must be between 0 and 100. save_intermediate: bool (default=False) Whether to save intermediate results (before applying the postprocessing). The intermediate results will be saved in the GeoTiff format. @@ -180,7 +178,6 @@ class PostprocessParameters(BaseModel): enable: bool = Field(default=True) method: str = Field(default="smooth_probabilities") kernel_size: int = Field(default=5) - conf_threshold: int = Field(default=30) save_intermediate: bool = Field(default=False) keep_class_probs: bool = Field(default=False) @@ -213,9 +210,5 @@ def check_parameters(self): raise ValueError( f"Kernel size must be smaller than 25, got {self.kernel_size}" ) - if self.conf_threshold < 0 or self.conf_threshold > 100: - raise ValueError( - f"Confidence threshold must be between 0 and 100, got {self.conf_threshold}" - ) return self diff --git a/tests/worldcerealtests/test_postprocessing.py b/tests/worldcerealtests/test_postprocessing.py index 022774e0..66112a88 100644 --- a/tests/worldcerealtests/test_postprocessing.py +++ b/tests/worldcerealtests/test_postprocessing.py @@ -50,7 +50,6 @@ def test_cropland_postprocessing_majority_vote(WorldCerealCroplandClassification "lookup_table": lookup_table, "method": "majority_vote", "kernel_size": 7, - "conf_threshold": 30, }, ) @@ -90,7 +89,6 @@ def test_croptype_postprocessing_majority_vote(WorldCerealCroptypeClassification "lookup_table": lookup_table, "method": "majority_vote", "kernel_size": 7, - "conf_threshold": 30, }, ) @@ -103,7 +101,6 @@ def test_postprocessing_parameters(): "enable": True, "method": "smooth_probabilities", "kernel_size": 5, - "conf_threshold": 30, "save_intermediate": False, "keep_class_probs": False, } @@ -118,12 +115,6 @@ def test_postprocessing_parameters(): with pytest.raises(ValueError): PostprocessParameters(**params) - # This one should fail with invalid conf_threshold - params["kernel_size"] = 5 - params["conf_threshold"] = 101 - with pytest.raises(ValueError): - PostprocessParameters(**params) - # This one should fail with invalid method params["method"] = "test" with pytest.raises(ValueError):