|
15 | 15 | from openeo_gfmap.inference.model_inference import apply_model_inference_local
|
16 | 16 |
|
17 | 17 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
|
18 |
| -from worldcereal.openeo.inference import CroptypeClassifier |
| 18 | +from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier |
19 | 19 |
|
20 | 20 | TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
|
21 | 21 | TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
|
|
40 | 40 | .astype("uint16")
|
41 | 41 | )
|
42 | 42 |
|
43 |
| - print("Running presto UDF locally") |
44 |
| - features = apply_feature_extractor_local( |
| 43 | + print("Get Presto cropland features") |
| 44 | + cropland_features = apply_feature_extractor_local( |
45 | 45 | PrestoFeatureExtractor,
|
46 | 46 | arr,
|
| 47 | + parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True}, |
| 48 | + ) |
| 49 | + |
| 50 | + print("Running cropland classification inference UDF locally") |
| 51 | + |
| 52 | + cropland_classification = apply_model_inference_local( |
| 53 | + CroplandClassifier, |
| 54 | + cropland_features, |
47 | 55 | parameters={
|
48 | 56 | EPSG_HARMONIZED_NAME: 32631,
|
49 | 57 | "ignore_dependencies": True,
|
50 |
| - "presto_model_url": PRESTO_URL, |
51 | 58 | },
|
52 | 59 | )
|
53 | 60 |
|
54 |
| - features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc") |
| 61 | + print("Get Presto croptype features") |
| 62 | + croptype_features = apply_feature_extractor_local( |
| 63 | + PrestoFeatureExtractor, |
| 64 | + arr, |
| 65 | + parameters={ |
| 66 | + EPSG_HARMONIZED_NAME: 32631, |
| 67 | + "ignore_dependencies": True, |
| 68 | + "presto_model_url": PRESTO_URL, |
| 69 | + }, |
| 70 | + ) |
55 | 71 |
|
56 |
| - print("Running classification inference UDF locally") |
| 72 | + print("Running croptype classification inference UDF locally") |
57 | 73 |
|
58 |
| - classification = apply_model_inference_local( |
| 74 | + croptype_classification = apply_model_inference_local( |
59 | 75 | CroptypeClassifier,
|
60 |
| - features, |
| 76 | + croptype_features, |
61 | 77 | parameters={
|
62 | 78 | EPSG_HARMONIZED_NAME: 32631,
|
63 | 79 | "ignore_dependencies": True,
|
64 | 80 | "classifier_url": CATBOOST_URL,
|
65 | 81 | },
|
66 | 82 | )
|
67 | 83 |
|
68 |
| - classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc") |
| 84 | + # Apply cropland mask -> on the backend this is done with mask process |
| 85 | + croptype_classification = croptype_classification.where( |
| 86 | + cropland_classification.sel(bands="classification") == 1, 0 |
| 87 | + ) |
| 88 | + |
| 89 | + croptype_classification.to_netcdf( |
| 90 | + Path("/vitodata/worldcereal/validation/internal_validation/") |
| 91 | + / "test_classification_croptype_local.nc" |
| 92 | + ) |
0 commit comments