Skip to content

Commit

Permalink
Merge pull request #188 from WorldCereal/simplify-majority-vote
Browse files Browse the repository at this point in the history
get rid of conf_threshold parameter as it may introduce "undecided" p…
  • Loading branch information
jdegerickx authored Oct 15, 2024
2 parents db818e4 + eb7f398 commit 9c60085
Show file tree
Hide file tree
Showing 6 changed files with 3 additions and 33 deletions.
2 changes: 0 additions & 2 deletions notebooks/worldcereal_v1_demo_custom_cropland.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@
"postprocess_method = \"majority_vote\"\n",
"# Additiona parameters for the majority vote method:\n",
"kernel_size = 3 # default = 5\n",
"conf_threshold = 60 # default = 30\n",
"# Do you want to save the intermediate results (before applying the postprocessing)\n",
"save_intermediate = True #default is False\n",
"# Do you want to save all class probabilities in the final product? (default is False)\n",
Expand All @@ -499,7 +498,6 @@
"postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n",
" method=postprocess_method,\n",
" kernel_size=kernel_size,\n",
" conf_threshold=conf_threshold,\n",
" save_intermediate=save_intermediate,\n",
" keep_class_probs=keep_class_probs)\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions notebooks/worldcereal_v1_demo_custom_croptype.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,6 @@
"postprocess_method = \"majority_vote\"\n",
"# Additiona parameters for the majority vote method:\n",
"kernel_size = 5 # default = 5\n",
"conf_threshold = 30 # default = 30\n",
"# Do you want to save the intermediate results (before applying the postprocessing)\n",
"save_intermediate = True #default is False\n",
"# Do you want to save all class probabilities in the final product? (default is False)\n",
Expand All @@ -517,7 +516,6 @@
"postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n",
" method=postprocess_method,\n",
" kernel_size=kernel_size,\n",
" conf_threshold=conf_threshold,\n",
" save_intermediate=save_intermediate,\n",
" keep_class_probs=keep_class_probs)\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions notebooks/worldcereal_v1_demo_default_cropland_EXTENDED.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@
"postprocess_method = \"majority_vote\"\n",
"# Additiona parameters for the majority vote method:\n",
"kernel_size = 3 # default = 5\n",
"conf_threshold = 60 # default = 30\n",
"# Do you want to save the intermediate results (before applying the postprocessing)\n",
"save_intermediate = True #default is False\n",
"# Do you want to save all class probabilities in the final product? (default is False)\n",
Expand All @@ -231,7 +230,6 @@
"postprocess_parameters = PostprocessParameters(enable=postprocess_result,\n",
" method=postprocess_method,\n",
" kernel_size=kernel_size,\n",
" conf_threshold=conf_threshold,\n",
" save_intermediate=save_intermediate,\n",
" keep_class_probs=keep_class_probs)\n",
"\n",
Expand Down
14 changes: 3 additions & 11 deletions src/worldcereal/openeo/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ def majority_vote(
base_labels: xr.DataArray,
max_probabilities: xr.DataArray,
kernel_size: int,
conf_threshold: int,
) -> xr.DataArray:
"""Majority vote is performed using a sliding local kernel.
For each pixel, the voting of a final class is done from
neighbours values weighted with the confidence threshold.
For each pixel, the voting of a final class is done by counting
neighbours values.
Pixels that have one of the specified excluded values are
excluded in the voting process and are unchanged.
Expand All @@ -55,8 +54,6 @@ def majority_vote(
The original probabilities of the winning class (ranging between 0 and 100).
kernel_size : int
The size of the kernel used for the neighbour around the pixel.
conf_threshold : int
Pixels under this confidence threshold do not count into the voting process.
Returns
-------
Expand Down Expand Up @@ -93,9 +90,6 @@ def majority_vote(
# Take the binary mask of the interest class, and multiply by the probabilities
class_mask = ((prediction == cls_value) * probability).astype(np.uint16)

# Sets to 0 the class scores where the threshold is lower
class_mask[probability <= conf_threshold] = 0

# Set to 0 the class scores where the label is excluded
for excluded_value in cls.EXCLUDED_VALUES:
class_mask[prediction == excluded_value] = 0
Expand Down Expand Up @@ -156,7 +150,7 @@ def majority_vote(
# Setting excluded values back to their original values
for excluded_value in cls.EXCLUDED_VALUES:
aggregated_predictions[prediction == excluded_value] = excluded_value
aggregated_probabilities[prediction == excluded_value] = cls.NODATA
aggregated_probabilities[prediction == excluded_value] = excluded_value

return xr.DataArray(
np.stack((aggregated_predictions, aggregated_probabilities)),
Expand Down Expand Up @@ -286,13 +280,11 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
elif self._parameters.get("method") == "majority_vote":

kernel_size = self._parameters.get("kernel_size")
conf_threshold = self._parameters.get("conf_threshold")

new_labels = PostProcessor.majority_vote(
inarr.sel(bands="classification"),
inarr.sel(bands="probability"),
kernel_size=kernel_size,
conf_threshold=conf_threshold,
)

# Append the per-class probabalities if required
Expand Down
7 changes: 0 additions & 7 deletions src/worldcereal/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ class PostprocessParameters(BaseModel):
The method to use for postprocessing. Must be one of ["smooth_probabilities", "majority_vote"]
kernel_size: int (default=5)
Used for majority vote postprocessing. Must be smaller than 25.
conf_threshold: int (default=30)
Used for majority vote postprocessing. Must be between 0 and 100.
save_intermediate: bool (default=False)
Whether to save intermediate results (before applying the postprocessing).
The intermediate results will be saved in the GeoTiff format.
Expand All @@ -180,7 +178,6 @@ class PostprocessParameters(BaseModel):
enable: bool = Field(default=True)
method: str = Field(default="smooth_probabilities")
kernel_size: int = Field(default=5)
conf_threshold: int = Field(default=30)
save_intermediate: bool = Field(default=False)
keep_class_probs: bool = Field(default=False)

Expand Down Expand Up @@ -213,9 +210,5 @@ def check_parameters(self):
raise ValueError(
f"Kernel size must be smaller than 25, got {self.kernel_size}"
)
if self.conf_threshold < 0 or self.conf_threshold > 100:
raise ValueError(
f"Confidence threshold must be between 0 and 100, got {self.conf_threshold}"
)

return self
9 changes: 0 additions & 9 deletions tests/worldcerealtests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_cropland_postprocessing_majority_vote(WorldCerealCroplandClassification
"lookup_table": lookup_table,
"method": "majority_vote",
"kernel_size": 7,
"conf_threshold": 30,
},
)

Expand Down Expand Up @@ -90,7 +89,6 @@ def test_croptype_postprocessing_majority_vote(WorldCerealCroptypeClassification
"lookup_table": lookup_table,
"method": "majority_vote",
"kernel_size": 7,
"conf_threshold": 30,
},
)

Expand All @@ -103,7 +101,6 @@ def test_postprocessing_parameters():
"enable": True,
"method": "smooth_probabilities",
"kernel_size": 5,
"conf_threshold": 30,
"save_intermediate": False,
"keep_class_probs": False,
}
Expand All @@ -118,12 +115,6 @@ def test_postprocessing_parameters():
with pytest.raises(ValueError):
PostprocessParameters(**params)

# This one should fail with invalid conf_threshold
params["kernel_size"] = 5
params["conf_threshold"] = 101
with pytest.raises(ValueError):
PostprocessParameters(**params)

# This one should fail with invalid method
params["method"] = "test"
with pytest.raises(ValueError):
Expand Down

0 comments on commit 9c60085

Please sign in to comment.