From 6f901b27d00fc7283e098f226c2d639bb2a0ff42 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Mon, 24 Jun 2024 09:49:52 +0200 Subject: [PATCH 1/4] Added an enum for worldcereal products --- scripts/inference/cropland_mapping.py | 2 +- src/worldcereal/job.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 1a2bf5b4..4d106a55 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -56,5 +56,5 @@ backend_context, args.output_path, product="cropland", - format="GTiff", + out_format="GTiff", ) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 934f77d9..fcc6408b 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -1,3 +1,5 @@ +"""Executing inference jobs on the OpenEO backend.""" +from enum import Enum from pathlib import Path from typing import Union @@ -14,13 +16,20 @@ ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" +class WorldCerealProduct(Enum): + """Enum to define the different WorldCereal products.""" + + CROPLAND = "cropland" + CROPTPE = "croptype" + + def generate_map( spatial_extent: BoundingBoxExtent, temporal_extent: TemporalContext, backend_context: BackendContext, output_path: Union[Path, str], - product: str = "cropland", - format: str = "GTiff", + product: WorldCerealProduct = WorldCerealProduct.CROPLAND, + out_format: str = "GTiff", ): """Main function to generate a WorldCereal product. @@ -69,15 +78,15 @@ def generate_map( ], ) - if product == "cropland": + if product == WorldCerealProduct.CROPLAND: # initiate default cropland model model_inference_class = CroplandClassifier model_inference_parameters = {} else: raise ValueError(f"Product {product} not supported.") - if format not in ["GTiff", "NetCDF"]: - raise ValueError(f"Format {format} not supported.") + if out_format not in ["GTiff", "NetCDF"]: + raise ValueError(f"Format {out_format} not supported.") classes = apply_model_inference( model_inference_class=model_inference_class, @@ -95,14 +104,14 @@ def generate_map( ) # Cast to uint8 - if product == "cropland": + if product == WorldCerealProduct.CROPLAND: classes = compress_uint8(classes) else: classes = compress_uint16(classes) classes.execute_batch( outputfile=output_path, - out_format=format, + out_format=out_format, job_options={ "driver-memory": "4g", "executor-memoryOverhead": "12g", From b4cea62c4149d39d1136ed1e1cc53399fb5ee585 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Mon, 24 Jun 2024 13:10:19 +0200 Subject: [PATCH 2/4] Added inference result dataclass --- scripts/inference/cropland_mapping.py | 8 +++-- src/worldcereal/job.py | 46 +++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 4d106a55..5068a184 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -3,10 +3,11 @@ import argparse from pathlib import Path +from loguru import logger from openeo_gfmap import BoundingBoxExtent, TemporalContext from openeo_gfmap.backend import Backend, BackendContext -from worldcereal.job import generate_map +from worldcereal.job import WorldCerealProduct, generate_map if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -50,11 +51,12 @@ backend_context = BackendContext(Backend.FED) - generate_map( + job_results = generate_map( spatial_extent, temporal_extent, backend_context, args.output_path, - product="cropland", + product_type=WorldCerealProduct.CROPLAND, out_format="GTiff", ) + logger.success("Job finished:\n\t%s", job_results) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index fcc6408b..8952d6c6 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -1,7 +1,8 @@ """Executing inference jobs on the OpenEO backend.""" +from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Union +from typing import Optional, Union import openeo from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext @@ -23,12 +24,34 @@ class WorldCerealProduct(Enum): CROPTPE = "croptype" +@dataclass +class WorldCerealResults: + """Dataclass to store the results of the WorldCereal job. + + Attributes + ---------- + job_id : str + Job ID of the finished OpenEO job. + product_url : str + Public URL to the product accessible of the resulting OpenEO job. + output_path : Optional[Path] + Path to the output file, if it was downloaded locally. + product : WorldCerealProduct + Product that was generated. + """ + + job_id: str + product_url: str + output_path: Optional[Path] + product: WorldCerealProduct + + def generate_map( spatial_extent: BoundingBoxExtent, temporal_extent: TemporalContext, backend_context: BackendContext, - output_path: Union[Path, str], - product: WorldCerealProduct = WorldCerealProduct.CROPLAND, + output_path: Optional[Union[Path, str]], + product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND, out_format: str = "GTiff", ): """Main function to generate a WorldCereal product. @@ -78,12 +101,12 @@ def generate_map( ], ) - if product == WorldCerealProduct.CROPLAND: + if product_type == WorldCerealProduct.CROPLAND: # initiate default cropland model model_inference_class = CroplandClassifier model_inference_parameters = {} else: - raise ValueError(f"Product {product} not supported.") + raise ValueError(f"Product {product_type} not supported.") if out_format not in ["GTiff", "NetCDF"]: raise ValueError(f"Format {out_format} not supported.") @@ -104,12 +127,12 @@ def generate_map( ) # Cast to uint8 - if product == WorldCerealProduct.CROPLAND: + if product_type == WorldCerealProduct.CROPLAND: classes = compress_uint8(classes) else: classes = compress_uint16(classes) - classes.execute_batch( + job = classes.execute_batch( outputfile=output_path, out_format=out_format, job_options={ @@ -118,3 +141,12 @@ def generate_map( "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], }, ) + # Should contain a single job as this is a single-jon tile inference. + asset = job.get_results().get_assets()[0] + + return WorldCerealResults( + job_id=classes.job_id, + product_url=asset.href, + output_path=output_path, + product=product_type, + ) From 199d55c01837f482e7413bbb59d1374b7bd77002 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Mon, 24 Jun 2024 13:12:31 +0200 Subject: [PATCH 3/4] Rename class --- src/worldcereal/job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 8952d6c6..044357c5 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -25,7 +25,7 @@ class WorldCerealProduct(Enum): @dataclass -class WorldCerealResults: +class InferenceResults: """Dataclass to store the results of the WorldCereal job. Attributes @@ -144,7 +144,7 @@ def generate_map( # Should contain a single job as this is a single-jon tile inference. asset = job.get_results().get_assets()[0] - return WorldCerealResults( + return InferenceResults( job_id=classes.job_id, product_url=asset.href, output_path=output_path, From 658aac755c44ce0368f5434185ca163def7bb9e0 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Mon, 24 Jun 2024 13:13:33 +0200 Subject: [PATCH 4/4] Renamed typo variable --- src/worldcereal/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 044357c5..6ffeb071 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -21,7 +21,7 @@ class WorldCerealProduct(Enum): """Enum to define the different WorldCereal products.""" CROPLAND = "cropland" - CROPTPE = "croptype" + CROPTYPE = "croptype" @dataclass