Skip to content

Commit 595099c

Browse files
authored
Merge pull request #85 from WorldCereal/cropland-masking
Update crop type mapping workflow
2 parents 012d6b7 + 3437b23 commit 595099c

File tree

5 files changed

+243
-78
lines changed

5 files changed

+243
-78
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies:
1616
- openeo=0.29.0
1717
- pyarrow=16.1.0
1818
- python=3.10.0
19-
- pytorch=2.3.0
19+
- pytorch=2.3.1
2020
- rasterio=1.3.10
2121
- rioxarray=0.15.5
2222
- scikit-image=0.22.0

scripts/inference/cropland_mapping.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
if __name__ == "__main__":
1313
parser = argparse.ArgumentParser(
14-
prog="WC - Cropland Inference",
15-
description="Cropland inference using GFMAP, Presto and WorldCereal classifiers",
14+
prog="WC - Crop Mapping Inference",
15+
description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers",
1616
)
1717

1818
parser.add_argument("minx", type=float, help="Minimum X coordinate (west)")
@@ -25,6 +25,11 @@
2525
default=4326,
2626
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
2727
)
28+
parser.add_argument(
29+
"product",
30+
type=str,
31+
help="Product to generate. One of ['cropland', 'croptype']",
32+
)
2833
parser.add_argument(
2934
"start_date", type=str, help="Starting date for data extraction."
3035
)
@@ -46,6 +51,15 @@
4651
start_date = args.start_date
4752
end_date = args.end_date
4853

54+
product = args.product
55+
56+
# minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test
57+
# minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test
58+
# epsg = 32631
59+
# start_date = "2020-11-01"
60+
# end_date = "2021-10-31"
61+
# product = "croptype"
62+
4963
spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
5064
temporal_extent = TemporalContext(start_date, end_date)
5165

@@ -56,7 +70,7 @@
5670
temporal_extent,
5771
backend_context,
5872
args.output_path,
59-
product_type=WorldCerealProduct.CROPLAND,
73+
product_type=WorldCerealProduct(product),
6074
out_format="GTiff",
6175
)
6276
logger.success("Job finished:\n\t%s", job_results)

scripts/inference/croptype_mapping_local.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from openeo_gfmap.inference.model_inference import apply_model_inference_local
1616

1717
from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
18-
from worldcereal.openeo.inference import CroptypeClassifier
18+
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
1919

2020
TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
2121
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
@@ -40,29 +40,53 @@
4040
.astype("uint16")
4141
)
4242

43-
print("Running presto UDF locally")
44-
features = apply_feature_extractor_local(
43+
print("Get Presto cropland features")
44+
cropland_features = apply_feature_extractor_local(
4545
PrestoFeatureExtractor,
4646
arr,
47+
parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True},
48+
)
49+
50+
print("Running cropland classification inference UDF locally")
51+
52+
cropland_classification = apply_model_inference_local(
53+
CroplandClassifier,
54+
cropland_features,
4755
parameters={
4856
EPSG_HARMONIZED_NAME: 32631,
4957
"ignore_dependencies": True,
50-
"presto_model_url": PRESTO_URL,
5158
},
5259
)
5360

54-
features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc")
61+
print("Get Presto croptype features")
62+
croptype_features = apply_feature_extractor_local(
63+
PrestoFeatureExtractor,
64+
arr,
65+
parameters={
66+
EPSG_HARMONIZED_NAME: 32631,
67+
"ignore_dependencies": True,
68+
"presto_model_url": PRESTO_URL,
69+
},
70+
)
5571

56-
print("Running classification inference UDF locally")
72+
print("Running croptype classification inference UDF locally")
5773

58-
classification = apply_model_inference_local(
74+
croptype_classification = apply_model_inference_local(
5975
CroptypeClassifier,
60-
features,
76+
croptype_features,
6177
parameters={
6278
EPSG_HARMONIZED_NAME: 32631,
6379
"ignore_dependencies": True,
6480
"classifier_url": CATBOOST_URL,
6581
},
6682
)
6783

68-
classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc")
84+
# Apply cropland mask -> on the backend this is done with mask process
85+
croptype_classification = croptype_classification.where(
86+
cropland_classification.sel(bands="classification") == 1, 0
87+
)
88+
89+
croptype_classification.to_netcdf(
90+
Path("/vitodata/worldcereal/validation/internal_validation/")
91+
/ "test_classification_croptype_local.nc"
92+
)

0 commit comments

Comments
 (0)