Skip to content

Commit

Permalink
Now saves uint16 values with 0-1 for crop/no-crop and 0-100 for proba…
Browse files Browse the repository at this point in the history
…bilities
  • Loading branch information
GriffinBabe committed Jun 20, 2024
1 parent ea4317a commit 1166fa0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 4 additions & 0 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openeo_gfmap.backend import Backend, BackendContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint16

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
Expand Down Expand Up @@ -105,6 +106,9 @@
],
)

# Cast to uint16
classes = compress_uint16(classes)

classes.execute_batch(
outputfile=args.output_path,
out_format="GTiff",
Expand Down
11 changes: 7 additions & 4 deletions src/worldcereal/openeo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ def predict(self, features: np.ndarray) -> np.ndarray:
# Extract all prediction values and convert them to binary labels
prediction_values = [sublist["True"] for sublist in outputs[1]]
binary_labels = np.array(prediction_values) >= threshold
binary_labels = binary_labels.astype(int)
binary_labels = binary_labels.astype("uint8")

return np.stack([binary_labels, prediction_values], axis=0).astype(np.float32)
prediction_values = np.array(prediction_values) * 100.0
prediction_values = np.round(prediction_values).astype("uint8")

return np.stack([binary_labels, prediction_values], axis=0)

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)
Expand All @@ -62,9 +65,9 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
self.onnx_session = self.load_ort_session(classifier_url)

# Run catboost classification
self.logger.info(f"Catboost classification with input shape: {inarr.shape}")
self.logger.info("Catboost classification with input shape: %s", inarr.shape)
classification = self.predict(inarr.values)
self.logger.info(f"Classification done with shape: {classification.shape}")
self.logger.info("Classification done with shape: %s", inarr.shape)

classification = xr.DataArray(
classification.reshape((2, len(x_coords), len(y_coords))),
Expand Down

0 comments on commit 1166fa0

Please sign in to comment.