From 5f9627f65476f93e09e57f57f81faeeb3ae67b47 Mon Sep 17 00:00:00 2001 From: Abyot Asalefew Gizaw Date: Wed, 11 Mar 2026 23:33:59 +0100 Subject: [PATCH 1/5] Implement declarative workflow engine with discoverable components --- data/workflows/dhis2_datavalue_set.yaml | 8 + ...alue_set_without_temporal_aggregation.yaml | 7 + src/eo_api/components/__init__.py | 6 + src/eo_api/components/routes.py | 117 +++++ src/eo_api/components/schemas.py | 121 +++++ src/eo_api/components/services.py | 141 +++++ src/eo_api/data_accessor/__init__.py | 5 +- src/eo_api/data_accessor/routes.py | 27 +- src/eo_api/data_accessor/services/__init__.py | 4 +- src/eo_api/data_accessor/services/accessor.py | 27 +- src/eo_api/data_manager/__init__.py | 5 +- src/eo_api/data_manager/routes.py | 16 +- src/eo_api/data_manager/services/__init__.py | 6 +- .../data_manager/services/downloader.py | 19 +- src/eo_api/data_manager/services/utils.py | 3 - src/eo_api/data_registry/__init__.py | 5 +- src/eo_api/data_registry/routes.py | 9 +- src/eo_api/data_registry/services/__init__.py | 4 +- src/eo_api/data_registry/services/datasets.py | 2 +- src/eo_api/main.py | 12 +- src/eo_api/shared/dhis2_adapter.py | 2 +- src/eo_api/shared/time.py | 5 +- src/eo_api/startup.py | 3 +- src/eo_api/system/__init__.py | 5 +- src/eo_api/system/routes.py | 1 - src/eo_api/workflows/__init__.py | 6 + src/eo_api/workflows/routes.py | 41 ++ src/eo_api/workflows/schemas.py | 176 +++++++ src/eo_api/workflows/services/__init__.py | 1 + src/eo_api/workflows/services/datavalueset.py | 65 +++ src/eo_api/workflows/services/definitions.py | 132 +++++ src/eo_api/workflows/services/engine.py | 207 ++++++++ src/eo_api/workflows/services/features.py | 69 +++ src/eo_api/workflows/services/preflight.py | 23 + src/eo_api/workflows/services/run_logs.py | 37 ++ src/eo_api/workflows/services/runtime.py | 89 ++++ .../workflows/services/simple_mapper.py | 86 +++ src/eo_api/workflows/services/spatial.py | 56 ++ 
src/eo_api/workflows/services/temporal.py | 25 + tests/test_root.py | 2 +- tests/test_workflows.py | 490 ++++++++++++++++++ 41 files changed, 1999 insertions(+), 66 deletions(-) create mode 100644 data/workflows/dhis2_datavalue_set.yaml create mode 100644 data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml create mode 100644 src/eo_api/components/__init__.py create mode 100644 src/eo_api/components/routes.py create mode 100644 src/eo_api/components/schemas.py create mode 100644 src/eo_api/components/services.py create mode 100644 src/eo_api/workflows/__init__.py create mode 100644 src/eo_api/workflows/routes.py create mode 100644 src/eo_api/workflows/schemas.py create mode 100644 src/eo_api/workflows/services/__init__.py create mode 100644 src/eo_api/workflows/services/datavalueset.py create mode 100644 src/eo_api/workflows/services/definitions.py create mode 100644 src/eo_api/workflows/services/engine.py create mode 100644 src/eo_api/workflows/services/features.py create mode 100644 src/eo_api/workflows/services/preflight.py create mode 100644 src/eo_api/workflows/services/run_logs.py create mode 100644 src/eo_api/workflows/services/runtime.py create mode 100644 src/eo_api/workflows/services/simple_mapper.py create mode 100644 src/eo_api/workflows/services/spatial.py create mode 100644 src/eo_api/workflows/services/temporal.py create mode 100644 tests/test_workflows.py diff --git a/data/workflows/dhis2_datavalue_set.yaml b/data/workflows/dhis2_datavalue_set.yaml new file mode 100644 index 0000000..180da1d --- /dev/null +++ b/data/workflows/dhis2_datavalue_set.yaml @@ -0,0 +1,8 @@ +workflow_id: dhis2_datavalue_set_v1 +version: 1 +steps: + - component: feature_source + - component: download_dataset + - component: temporal_aggregation + - component: spatial_aggregation + - component: build_datavalueset diff --git a/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml 
b/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml new file mode 100644 index 0000000..6d1b1f0 --- /dev/null +++ b/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml @@ -0,0 +1,7 @@ +workflow_id: dhis2_datavalue_set_without_temporal_aggregation_v1 +version: 1 +steps: + - component: feature_source + - component: download_dataset + - component: spatial_aggregation + - component: build_datavalueset diff --git a/src/eo_api/components/__init__.py b/src/eo_api/components/__init__.py new file mode 100644 index 0000000..dc357fd --- /dev/null +++ b/src/eo_api/components/__init__.py @@ -0,0 +1,6 @@ +"""Reusable workflow components exposed as API capabilities.""" + +from . import routes as routes +from . import services as services + +__all__ = ["routes", "services"] diff --git a/src/eo_api/components/routes.py b/src/eo_api/components/routes.py new file mode 100644 index 0000000..73519c2 --- /dev/null +++ b/src/eo_api/components/routes.py @@ -0,0 +1,117 @@ +"""Component discovery and execution endpoints.""" + +from __future__ import annotations + +from fastapi import APIRouter + +from ..data_manager.services.constants import BBOX +from . 
import services +from .schemas import ( + BuildDataValueSetRunRequest, + BuildDataValueSetRunResponse, + ComponentCatalogResponse, + DownloadDatasetRunRequest, + DownloadDatasetRunResponse, + FeatureSourceRunRequest, + FeatureSourceRunResponse, + SpatialAggregationRunRequest, + SpatialAggregationRunResponse, + TemporalAggregationRunRequest, + TemporalAggregationRunResponse, +) + +router = APIRouter() + + +@router.get("/components", response_model=ComponentCatalogResponse) +def list_components() -> ComponentCatalogResponse: + """List all discoverable reusable components.""" + return ComponentCatalogResponse(components=services.component_catalog()) + + +@router.post("/components/feature-source", response_model=FeatureSourceRunResponse) +def run_feature_source(payload: FeatureSourceRunRequest) -> FeatureSourceRunResponse: + """Resolve feature source to features and bbox.""" + features, bbox = services.feature_source_component(payload.feature_source) + return FeatureSourceRunResponse( + bbox=bbox, + feature_count=len(features["features"]), + features=features if payload.include_features else None, + ) + + +@router.post("/components/download-dataset", response_model=DownloadDatasetRunResponse) +def run_download_dataset(payload: DownloadDatasetRunRequest) -> DownloadDatasetRunResponse: + """Download dataset files for the selected period/scope.""" + dataset = services.require_dataset(payload.dataset_id) + bbox = payload.bbox or BBOX + services.download_dataset_component( + dataset=dataset, + start=payload.start, + end=payload.end, + overwrite=payload.overwrite, + country_code=payload.country_code, + bbox=bbox, + ) + return DownloadDatasetRunResponse( + status="completed", + dataset_id=payload.dataset_id, + start=payload.start, + end=payload.end, + ) + + +@router.post("/components/temporal-aggregation", response_model=TemporalAggregationRunResponse) +def run_temporal_aggregation(payload: TemporalAggregationRunRequest) -> TemporalAggregationRunResponse: + """Aggregate a 
dataset temporally.""" + dataset = services.require_dataset(payload.dataset_id) + ds = services.temporal_aggregation_component( + dataset=dataset, + start=payload.start, + end=payload.end, + bbox=payload.bbox, + target_period_type=payload.target_period_type, + method=payload.method, + ) + return TemporalAggregationRunResponse( + dataset_id=payload.dataset_id, + sizes={str(k): int(v) for k, v in ds.sizes.items()}, + dims=[str(d) for d in ds.dims], + ) + + +@router.post("/components/spatial-aggregation", response_model=SpatialAggregationRunResponse) +def run_spatial_aggregation(payload: SpatialAggregationRunRequest) -> SpatialAggregationRunResponse: + """Aggregate a dataset spatially to features.""" + dataset = services.require_dataset(payload.dataset_id) + features, bbox = services.feature_source_component(payload.feature_source) + records = services.spatial_aggregation_component( + dataset=dataset, + start=payload.start, + end=payload.end, + bbox=payload.bbox or bbox, + features=features, + method=payload.method, + feature_id_property=payload.feature_id_property, + ) + return SpatialAggregationRunResponse( + dataset_id=payload.dataset_id, + record_count=len(records), + preview=records[: payload.max_preview_rows], + ) + + +@router.post("/components/build-datavalue-set", response_model=BuildDataValueSetRunResponse) +def run_build_datavalueset(payload: BuildDataValueSetRunRequest) -> BuildDataValueSetRunResponse: + """Build and serialize a DHIS2 DataValueSet from records.""" + data_value_set, output_file = services.build_datavalueset_component( + dataset_id=payload.dataset_id, + period_type=payload.period_type, + records=payload.records, + dhis2=payload.dhis2, + ) + return BuildDataValueSetRunResponse( + value_count=len(data_value_set.get("dataValues", [])), + output_file=output_file, + data_value_set=data_value_set, + ) diff --git a/src/eo_api/components/schemas.py b/src/eo_api/components/schemas.py new file mode 100644 index 0000000..80c772e --- /dev/null +++ 
b/src/eo_api/components/schemas.py @@ -0,0 +1,121 @@ +"""Schemas for component discovery and execution endpoints.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +from ..workflows.schemas import ( + AggregationMethod, + Dhis2DataValueSetConfig, + FeatureSourceConfig, + PeriodType, +) + + +class ComponentDefinition(BaseModel): + """Component metadata for discovery.""" + + name: str + description: str + inputs: list[str] + outputs: list[str] + + +class ComponentCatalogResponse(BaseModel): + """List of discoverable components.""" + + components: list[ComponentDefinition] + + +class FeatureSourceRunRequest(BaseModel): + """Execute feature source component.""" + + feature_source: FeatureSourceConfig + include_features: bool = False + + +class FeatureSourceRunResponse(BaseModel): + """Feature source component result.""" + + bbox: list[float] + feature_count: int + features: dict[str, Any] | None = None + + +class DownloadDatasetRunRequest(BaseModel): + """Execute dataset download component.""" + + dataset_id: str + start: str + end: str + overwrite: bool = False + country_code: str | None = None + bbox: list[float] | None = None + + +class DownloadDatasetRunResponse(BaseModel): + """Download component result.""" + + status: str + dataset_id: str + start: str + end: str + + +class TemporalAggregationRunRequest(BaseModel): + """Execute temporal aggregation component from cached dataset.""" + + dataset_id: str + start: str + end: str + target_period_type: PeriodType + method: AggregationMethod = AggregationMethod.SUM + bbox: list[float] | None = None + + +class TemporalAggregationRunResponse(BaseModel): + """Temporal aggregation result summary.""" + + dataset_id: str + sizes: dict[str, int] + dims: list[str] + + +class SpatialAggregationRunRequest(BaseModel): + """Execute spatial aggregation component from cached dataset.""" + + dataset_id: str + start: str + end: str + feature_source: FeatureSourceConfig + 
method: AggregationMethod = AggregationMethod.MEAN + bbox: list[float] | None = None + feature_id_property: str = "id" + max_preview_rows: int = 20 + + +class SpatialAggregationRunResponse(BaseModel): + """Spatial aggregation result with sample rows.""" + + dataset_id: str + record_count: int + preview: list[dict[str, Any]] + + +class BuildDataValueSetRunRequest(BaseModel): + """Execute build_datavalueset component directly from records.""" + + dataset_id: str + period_type: PeriodType + records: list[dict[str, Any]] = Field(default_factory=list) + dhis2: Dhis2DataValueSetConfig + + +class BuildDataValueSetRunResponse(BaseModel): + """Build_datavalueset component output.""" + + value_count: int + output_file: str + data_value_set: dict[str, Any] diff --git a/src/eo_api/components/services.py b/src/eo_api/components/services.py new file mode 100644 index 0000000..2e99133 --- /dev/null +++ b/src/eo_api/components/services.py @@ -0,0 +1,141 @@ +"""Component service implementations and discovery metadata.""" + +from __future__ import annotations + +from typing import Any + +import xarray as xr +from fastapi import HTTPException + +from ..data_accessor.services.accessor import get_data +from ..data_manager.services import downloader +from ..data_registry.services.datasets import get_dataset +from ..workflows.schemas import ( + AggregationMethod, + Dhis2DataValueSetConfig, + FeatureSourceConfig, + PeriodType, +) +from ..workflows.services.datavalueset import build_data_value_set +from ..workflows.services.features import resolve_features +from ..workflows.services.preflight import check_upstream_connectivity +from ..workflows.services.spatial import aggregate_to_features +from ..workflows.services.temporal import aggregate_temporal +from .schemas import ComponentDefinition + + +def component_catalog() -> list[ComponentDefinition]: + """Return all discoverable component definitions.""" + return [ + ComponentDefinition( + name="feature_source", + description="Resolve 
feature source and compute bbox.", + inputs=["feature_source"], + outputs=["features", "bbox"], + ), + ComponentDefinition( + name="download_dataset", + description="Download dataset files for period and bbox.", + inputs=["dataset_id", "start", "end", "overwrite", "country_code", "bbox"], + outputs=["status"], + ), + ComponentDefinition( + name="temporal_aggregation", + description="Aggregate dataset over time dimension.", + inputs=["dataset_id", "start", "end", "target_period_type", "method", "bbox"], + outputs=["dataset"], + ), + ComponentDefinition( + name="spatial_aggregation", + description="Aggregate gridded dataset to features.", + inputs=["dataset_id", "start", "end", "feature_source", "method"], + outputs=["records"], + ), + ComponentDefinition( + name="build_datavalueset", + description="Build and serialize DHIS2 DataValueSet JSON.", + inputs=["dataset_id", "period_type", "records", "dhis2"], + outputs=["data_value_set", "output_file"], + ), + ] + + +def feature_source_component(config: FeatureSourceConfig) -> tuple[dict[str, Any], list[float]]: + """Run feature source component.""" + return resolve_features(config) + + +def download_dataset_component( + *, + dataset: dict[str, Any], + start: str, + end: str, + overwrite: bool, + country_code: str | None, + bbox: list[float], +) -> None: + """Run connectivity preflight and download dataset files.""" + check_upstream_connectivity(dataset) + downloader.download_dataset( + dataset=dataset, + start=start, + end=end, + overwrite=overwrite, + background_tasks=None, + country_code=country_code, + bbox=bbox, + ) + + +def temporal_aggregation_component( + *, + dataset: dict[str, Any], + start: str, + end: str, + bbox: list[float] | None, + target_period_type: PeriodType, + method: AggregationMethod, +) -> xr.Dataset: + """Load dataset and aggregate over time.""" + ds = get_data(dataset=dataset, start=start, end=end, bbox=bbox) + return aggregate_temporal(ds=ds, period_type=target_period_type, method=method) + + 
+def spatial_aggregation_component( + *, + dataset: dict[str, Any], + start: str, + end: str, + bbox: list[float] | None, + features: dict[str, Any], + method: AggregationMethod, + feature_id_property: str, +) -> list[dict[str, Any]]: + """Load dataset and aggregate spatially to provided features.""" + ds = get_data(dataset=dataset, start=start, end=end, bbox=bbox) + return aggregate_to_features( + ds=ds, + variable=dataset["variable"], + features=features, + method=method.value, + feature_id_property=feature_id_property, + ) + + +def build_datavalueset_component( + *, + dataset_id: str, + period_type: PeriodType, + records: list[dict[str, Any]], + dhis2: Dhis2DataValueSetConfig, +) -> tuple[dict[str, Any], str]: + """Build and serialize DHIS2 DataValueSet from records.""" + return build_data_value_set(records=records, dataset_id=dataset_id, period_type=period_type, config=dhis2) + + +def require_dataset(dataset_id: str) -> dict[str, Any]: + """Resolve dataset or raise 404.""" + dataset = get_dataset(dataset_id) + if dataset is None: + raise HTTPException(status_code=404, detail=f"Dataset '{dataset_id}' not found") + return dataset diff --git a/src/eo_api/data_accessor/__init__.py b/src/eo_api/data_accessor/__init__.py index 0fbcaa5..ee5067d 100644 --- a/src/eo_api/data_accessor/__init__.py +++ b/src/eo_api/data_accessor/__init__.py @@ -1 +1,4 @@ -from . import routes, services \ No newline at end of file +from . import routes as routes +from . 
import services as services + +__all__ = ["routes", "services"] diff --git a/src/eo_api/data_accessor/routes.py b/src/eo_api/data_accessor/routes.py index 8e3a6f5..507b27d 100644 --- a/src/eo_api/data_accessor/routes.py +++ b/src/eo_api/data_accessor/routes.py @@ -1,45 +1,44 @@ """FastAPI router exposing dataset endpoints.""" -from typing import Any - -import xarray as xr -from fastapi import APIRouter, BackgroundTasks, HTTPException, Response +from fastapi import APIRouter from fastapi.responses import FileResponse from starlette.background import BackgroundTask -from .services.accessor import cleanup_file, get_data, xarray_to_temporary_netcdf from ..data_registry.routes import _get_dataset_or_404 +from .services.accessor import cleanup_file, get_data, xarray_to_temporary_netcdf router = APIRouter() + @router.get("/{dataset_id}") def get_file( dataset_id: str, start: str, end: str, - xmin: float = None, - ymin: float = None, - xmax: float = None, - ymax: float = None, - format: str = 'netcdf', + xmin: float | None = None, + ymin: float | None = None, + xmax: float | None = None, + ymax: float | None = None, + format: str = "netcdf", ) -> FileResponse: """Get a dataset filtered to a timeperiod and bbox as a downloadable raster file.""" dataset = _get_dataset_or_404(dataset_id) # get filtered data - if all([xmin, ymin, xmax, ymax]): + bbox: list[float] | None + if xmin is not None and ymin is not None and xmax is not None and ymax is not None: bbox = [xmin, ymin, xmax, ymax] else: bbox = None ds = get_data(dataset, start, end, bbox) # save to temporary file - if format.lower() == 'netcdf': + if format.lower() == "netcdf": # convert to netcdf file_path = xarray_to_temporary_netcdf(ds) else: - raise ValueError(f'Unsupported output format: {format}') + raise ValueError(f"Unsupported output format: {format}") # return as file return FileResponse( @@ -47,4 +46,4 @@ def get_file( media_type="application/x-netcdf", filename="eo-api-raster-download.nc", 
background=BackgroundTask(cleanup_file, file_path), - ) \ No newline at end of file + ) diff --git a/src/eo_api/data_accessor/services/__init__.py b/src/eo_api/data_accessor/services/__init__.py index 512ee18..209f276 100644 --- a/src/eo_api/data_accessor/services/__init__.py +++ b/src/eo_api/data_accessor/services/__init__.py @@ -1 +1,3 @@ -from . import accessor \ No newline at end of file +from . import accessor as accessor + +__all__ = ["accessor"] diff --git a/src/eo_api/data_accessor/services/accessor.py b/src/eo_api/data_accessor/services/accessor.py index d254a94..b9b78e9 100644 --- a/src/eo_api/data_accessor/services/accessor.py +++ b/src/eo_api/data_accessor/services/accessor.py @@ -1,25 +1,30 @@ """Loading raster data from downloaded files into xarray.""" -import os -import json import logging +import os import tempfile from typing import Any import xarray as xr from ...data_manager.services.downloader import get_cache_files, get_zarr_path -from ...data_manager.services.utils import get_time_dim, get_lon_lat_dims +from ...data_manager.services.utils import get_lon_lat_dims, get_time_dim from ...shared.time import numpy_datetime_to_period_string logger = logging.getLogger(__name__) -def get_data(dataset: dict[str, Any], start: str = None, end: str = None, bbox: list = None) -> xr.Dataset: + +def get_data( + dataset: dict[str, Any], + start: str | None = None, + end: str | None = None, + bbox: list[float] | None = None, +) -> xr.Dataset: """Load an xarray raster dataset for a given time range and bbox.""" logger.info("Opening dataset") zarr_path = get_zarr_path(dataset) if zarr_path: - logger.info(f'Using optimized zarr file: {zarr_path}') + logger.info(f"Using optimized zarr file: {zarr_path}") ds = xr.open_zarr(zarr_path, consolidated=True) else: logger.warning( @@ -40,8 +45,8 @@ def get_data(dataset: dict[str, Any], start: str = None, end: str = None, bbox: if bbox is not None: logger.info(f"Subsetting xy to {bbox}") - xmin,ymin,xmax,ymax = 
list(map(float, bbox)) - lon_dim,lat_dim = get_lon_lat_dims(ds) + xmin, ymin, xmax, ymax = list(map(float, bbox)) + lon_dim, lat_dim = get_lon_lat_dims(ds) # TODO: this assumes y axis increases towards north and is not very stable # ...and also does not consider partial pixels at the edges # ...should probably switch to rioxarray.clip instead @@ -49,9 +54,9 @@ def get_data(dataset: dict[str, Any], start: str = None, end: str = None, bbox: return ds # type: ignore[no-any-return] + def get_data_coverage(dataset: dict[str, Any]) -> dict[str, Any]: """Return temporal and spatial coverage metadata for downloaded data.""" - ds = get_data(dataset) if not ds: @@ -60,8 +65,8 @@ def get_data_coverage(dataset: dict[str, Any]) -> dict[str, Any]: time_dim = get_time_dim(ds) lon_dim, lat_dim = get_lon_lat_dims(ds) - start = numpy_datetime_to_period_string(ds[time_dim].min(), dataset['period_type']) # type: ignore[arg-type] - end = numpy_datetime_to_period_string(ds[time_dim].max(), dataset['period_type']) # type: ignore[arg-type] + start = numpy_datetime_to_period_string(ds[time_dim].min(), dataset["period_type"]) # type: ignore[arg-type] + end = numpy_datetime_to_period_string(ds[time_dim].max(), dataset["period_type"]) # type: ignore[arg-type] xmin, xmax = ds[lon_dim].min().item(), ds[lon_dim].max().item() ymin, ymax = ds[lat_dim].min().item(), ds[lat_dim].max().item() @@ -73,6 +78,7 @@ def get_data_coverage(dataset: dict[str, Any]) -> dict[str, Any]: } } + def xarray_to_temporary_netcdf(ds: xr.Dataset) -> str: """Write a dataset to a temporary NetCDF file and return the path.""" fd = tempfile.NamedTemporaryFile(suffix=".nc", delete=False) @@ -81,6 +87,7 @@ def xarray_to_temporary_netcdf(ds: xr.Dataset) -> str: ds.to_netcdf(path) return path + def cleanup_file(path: str) -> None: """Remove a file from disk.""" os.remove(path) diff --git a/src/eo_api/data_manager/__init__.py b/src/eo_api/data_manager/__init__.py index 0fbcaa5..ee5067d 100644 --- 
a/src/eo_api/data_manager/__init__.py +++ b/src/eo_api/data_manager/__init__.py @@ -1 +1,4 @@ -from . import routes, services \ No newline at end of file +from . import routes as routes +from . import services as services + +__all__ = ["routes", "services"] diff --git a/src/eo_api/data_manager/routes.py b/src/eo_api/data_manager/routes.py index c6c5679..f369dbe 100644 --- a/src/eo_api/data_manager/routes.py +++ b/src/eo_api/data_manager/routes.py @@ -1,14 +1,9 @@ """FastAPI router exposing dataset endpoints.""" -from typing import Any +from fastapi import APIRouter, BackgroundTasks -import xarray as xr -from fastapi import APIRouter, BackgroundTasks, HTTPException, Response -from fastapi.responses import FileResponse -from starlette.background import BackgroundTask - -from .services import constants, downloader from ..data_registry.routes import _get_dataset_or_404 +from .services import downloader router = APIRouter() @@ -17,9 +12,9 @@ def download_dataset( dataset_id: str, start: str, + background_tasks: BackgroundTasks, end: str | None = None, overwrite: bool = False, - background_tasks: BackgroundTasks = None, ) -> dict[str, str]: """Download dataset as local netcdf files direct from the source.""" dataset = _get_dataset_or_404(dataset_id) @@ -30,10 +25,9 @@ def download_dataset( @router.get("/{dataset_id}/build_zarr", response_model=dict) def build_dataset_zarr( dataset_id: str, - background_tasks: BackgroundTasks = None, + background_tasks: BackgroundTasks, ) -> dict[str, str]: """Optimize dataset downloads by collecting all files to a single zarr archive.""" dataset = _get_dataset_or_404(dataset_id) - if background_tasks is not None: - background_tasks.add_task(downloader.build_dataset_zarr, dataset) + background_tasks.add_task(downloader.build_dataset_zarr, dataset) return {"status": "Building zarr file from dataset downloads"} diff --git a/src/eo_api/data_manager/services/__init__.py b/src/eo_api/data_manager/services/__init__.py index 2ba6614..a20a096 
100644 --- a/src/eo_api/data_manager/services/__init__.py +++ b/src/eo_api/data_manager/services/__init__.py @@ -1 +1,5 @@ -from . import constants, downloader, utils \ No newline at end of file +from . import constants as constants +from . import downloader as downloader +from . import utils as utils + +__all__ = ["constants", "downloader", "utils"] diff --git a/src/eo_api/data_manager/services/downloader.py b/src/eo_api/data_manager/services/downloader.py index 93c8a27..c74c06c 100644 --- a/src/eo_api/data_manager/services/downloader.py +++ b/src/eo_api/data_manager/services/downloader.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) SCRIPT_DIR = Path(__file__).parent.resolve() -_download_dir = SCRIPT_DIR.parent.parent.parent.parent / 'data' / 'downloads' +_download_dir = SCRIPT_DIR.parent.parent.parent.parent / "data" / "downloads" if CACHE_OVERRIDE: _download_dir = Path(CACHE_OVERRIDE) DOWNLOAD_DIR = _download_dir @@ -29,6 +29,8 @@ def download_dataset( end: str | None, overwrite: bool, background_tasks: BackgroundTasks | None, + country_code: str | None = None, + bbox: list[float] | None = None, ) -> None: """Download dataset from source and store as local NetCDF cache files.""" cache_info = dataset["cache_info"] @@ -48,15 +50,22 @@ def download_dataset( sig = inspect.signature(eo_download_func) if "bbox" in sig.parameters: - params["bbox"] = BBOX + params["bbox"] = bbox or BBOX elif "country_code" in sig.parameters: - if COUNTRY_CODE: - params["country_code"] = COUNTRY_CODE + resolved_country_code = country_code or COUNTRY_CODE + if resolved_country_code: + params["country_code"] = resolved_country_code else: - raise Exception('Downloading WorldPop data requires COUNTRY_CODE environment variable') + raise Exception( + "Downloading WorldPop data requires country_code input (or COUNTRY_CODE environment variable)" + ) + + DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True) if background_tasks is not None: background_tasks.add_task(eo_download_func, 
**params) + else: + eo_download_func(**params) def build_dataset_zarr(dataset: dict[str, Any]) -> None: diff --git a/src/eo_api/data_manager/services/utils.py b/src/eo_api/data_manager/services/utils.py index aa797cc..7e3bc0c 100644 --- a/src/eo_api/data_manager/services/utils.py +++ b/src/eo_api/data_manager/services/utils.py @@ -2,9 +2,6 @@ from typing import Any -import numpy as np -import pandas as pd - def get_time_dim(ds: Any) -> str: """Return the name of the time dimension in a dataset or dataframe.""" diff --git a/src/eo_api/data_registry/__init__.py b/src/eo_api/data_registry/__init__.py index 0fbcaa5..ee5067d 100644 --- a/src/eo_api/data_registry/__init__.py +++ b/src/eo_api/data_registry/__init__.py @@ -1 +1,4 @@ -from . import routes, services \ No newline at end of file +from . import routes as routes +from . import services as services + +__all__ = ["routes", "services"] diff --git a/src/eo_api/data_registry/routes.py b/src/eo_api/data_registry/routes.py index ffa306d..36f4fce 100644 --- a/src/eo_api/data_registry/routes.py +++ b/src/eo_api/data_registry/routes.py @@ -2,15 +2,13 @@ from typing import Any -import xarray as xr -from fastapi import APIRouter, BackgroundTasks, HTTPException, Response -from fastapi.responses import FileResponse -from starlette.background import BackgroundTask +from fastapi import APIRouter, HTTPException from .services import datasets router = APIRouter() + @router.get("/") def list_datasets() -> list[dict[str, Any]]: """Return list of available datasets from registry.""" @@ -30,7 +28,8 @@ def get_dataset(dataset_id: str) -> dict[str, Any]: """Get a single dataset by ID.""" # Note: have to import inside function to avoid circular import from ..data_accessor.services.accessor import get_data_coverage + dataset = _get_dataset_or_404(dataset_id) coverage = get_data_coverage(dataset) dataset.update(coverage) - return dataset \ No newline at end of file + return dataset diff --git 
a/src/eo_api/data_registry/services/__init__.py b/src/eo_api/data_registry/services/__init__.py index 08014f6..9d0231a 100644 --- a/src/eo_api/data_registry/services/__init__.py +++ b/src/eo_api/data_registry/services/__init__.py @@ -1 +1,3 @@ -from . import datasets \ No newline at end of file +from . import datasets as datasets + +__all__ = ["datasets"] diff --git a/src/eo_api/data_registry/services/datasets.py b/src/eo_api/data_registry/services/datasets.py index 371e8fc..9bcc5de 100644 --- a/src/eo_api/data_registry/services/datasets.py +++ b/src/eo_api/data_registry/services/datasets.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) SCRIPT_DIR = Path(__file__).parent.resolve() -CONFIGS_DIR = SCRIPT_DIR.parent.parent.parent.parent / 'data' / 'datasets' +CONFIGS_DIR = SCRIPT_DIR.parent.parent.parent.parent / "data" / "datasets" def list_datasets() -> list[dict[str, Any]]: diff --git a/src/eo_api/main.py b/src/eo_api/main.py index e12ab58..0ab8abb 100644 --- a/src/eo_api/main.py +++ b/src/eo_api/main.py @@ -4,7 +4,7 @@ from fastapi.middleware.cors import CORSMiddleware import eo_api.startup # noqa: F401 # pyright: ignore[reportUnusedImport] -from eo_api import data_accessor, data_manager, data_registry, system +from eo_api import components, data_accessor, data_manager, data_registry, system, workflows app = FastAPI() @@ -16,7 +16,9 @@ allow_headers=["*"], ) -app.include_router(system.routes.router, tags=['System']) -app.include_router(data_registry.routes.router, prefix='/registry', tags=['Data registry']) -app.include_router(data_manager.routes.router, prefix='/manage', tags=['Data manager']) -app.include_router(data_accessor.routes.router, prefix='/retrieve', tags=['Data retrieval']) +app.include_router(system.routes.router, tags=["System"]) +app.include_router(data_registry.routes.router, prefix="/registry", tags=["Data registry"]) +app.include_router(data_manager.routes.router, prefix="/manage", tags=["Data manager"]) 
+app.include_router(data_accessor.routes.router, prefix="/retrieve", tags=["Data retrieval"]) +app.include_router(workflows.routes.router, prefix="/workflows", tags=["Workflows"]) +app.include_router(components.routes.router, tags=["Components"]) diff --git a/src/eo_api/shared/dhis2_adapter.py b/src/eo_api/shared/dhis2_adapter.py index 0e9e5b1..ea54348 100644 --- a/src/eo_api/shared/dhis2_adapter.py +++ b/src/eo_api/shared/dhis2_adapter.py @@ -86,4 +86,4 @@ def get_org_unit_geojson(client: DHIS2Client, uid: str) -> dict[str, Any]: def get_org_unit_subtree_geojson(client: DHIS2Client, uid: str) -> dict[str, Any]: """Fetch a subtree of organisation units as GeoJSON.""" - return cast(dict[str, Any], client.get_org_unit_subtree_geojson(uid)) \ No newline at end of file + return cast(dict[str, Any], client.get_org_unit_subtree_geojson(uid)) diff --git a/src/eo_api/shared/time.py b/src/eo_api/shared/time.py index 269740f..5690a11 100644 --- a/src/eo_api/shared/time.py +++ b/src/eo_api/shared/time.py @@ -1,7 +1,10 @@ +"""Shared time conversion helpers.""" + from typing import Any import numpy as np + def numpy_datetime_to_period_string(datetimes: np.ndarray[Any, Any], period_type: str) -> np.ndarray[Any, Any]: """Convert an array of numpy datetimes to truncated period strings.""" # TODO: this and numpy_period_string should be merged @@ -9,4 +12,4 @@ def numpy_datetime_to_period_string(datetimes: np.ndarray[Any, Any], period_type # Map periods to string lengths: YYYY-MM-DDTHH (13), YYYY-MM-DD (10), etc. lengths = {"hourly": 13, "daily": 10, "monthly": 7, "yearly": 4} - return s.astype(f"U{lengths[period_type]}") \ No newline at end of file + return s.astype(f"U{lengths[period_type]}") diff --git a/src/eo_api/startup.py b/src/eo_api/startup.py index 1d1ba51..5c33ffa 100644 --- a/src/eo_api/startup.py +++ b/src/eo_api/startup.py @@ -1,10 +1,11 @@ -"""Early-boot side effects +"""Early-boot side effects. 
This module is imported before any other eo_api modules so that environment variables and logging are configured before other imports. """ import logging + from dotenv import load_dotenv # noqa: E402 # -- Load .env (must happen before pygeoapi reads PYGEOAPI_CONFIG) ------------ diff --git a/src/eo_api/system/__init__.py b/src/eo_api/system/__init__.py index 00e7846..865a013 100644 --- a/src/eo_api/system/__init__.py +++ b/src/eo_api/system/__init__.py @@ -1 +1,4 @@ -from . import routes, schemas \ No newline at end of file +from . import routes as routes +from . import schemas as schemas + +__all__ = ["routes", "schemas"] diff --git a/src/eo_api/system/routes.py b/src/eo_api/system/routes.py index 2639ea3..1c41b1f 100644 --- a/src/eo_api/system/routes.py +++ b/src/eo_api/system/routes.py @@ -4,7 +4,6 @@ from importlib.metadata import version from fastapi import APIRouter, Request -from fastapi.responses import RedirectResponse from .schemas import AppInfo, HealthStatus, Link, RootResponse, Status diff --git a/src/eo_api/workflows/__init__.py b/src/eo_api/workflows/__init__.py new file mode 100644 index 0000000..23f4ad6 --- /dev/null +++ b/src/eo_api/workflows/__init__.py @@ -0,0 +1,6 @@ +"""Workflow APIs for generic gridded-data to DHIS2 pipelines.""" + +from . import routes as routes +from . 
"""API routes for workflow discovery and execution."""

from fastapi import APIRouter, HTTPException

from .schemas import WorkflowCatalogItem, WorkflowCatalogResponse, WorkflowExecuteResponse, WorkflowRequest
from .services.definitions import list_workflow_definitions
from .services.engine import execute_workflow
from .services.simple_mapper import normalize_simple_request

router = APIRouter()


def _catalog_item(definition) -> WorkflowCatalogItem:
    """Summarize one workflow definition for the catalog listing."""
    return WorkflowCatalogItem(
        workflow_id=definition.workflow_id,
        version=definition.version,
        step_count=len(definition.steps),
        components=[step.component for step in definition.steps],
    )


@router.get("", response_model=WorkflowCatalogResponse)
def list_workflows() -> WorkflowCatalogResponse:
    """List all allowlisted workflow definitions."""
    try:
        definitions = list_workflow_definitions()
    except ValueError as exc:
        # Discovery failures (bad YAML, duplicate ids) are server-side errors.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return WorkflowCatalogResponse(workflows=[_catalog_item(definition) for definition in definitions])


@router.post("/dhis2-datavalue-set", response_model=WorkflowExecuteResponse)
def run_dhis2_datavalue_set_workflow(payload: WorkflowRequest) -> WorkflowExecuteResponse:
    """Run workflow from a single flat request payload."""
    internal_request, _warnings = normalize_simple_request(payload)
    return execute_workflow(
        internal_request,
        workflow_id=payload.workflow_id,
        include_component_run_details=payload.include_component_run_details,
    )
class FeatureSourceType(StrEnum):
    """Backends that can supply features for spatial aggregation."""

    GEOJSON_FILE = "geojson_file"
    DHIS2_LEVEL = "dhis2_level"
    DHIS2_IDS = "dhis2_ids"


class AggregationMethod(StrEnum):
    """Numeric reducers supported by the aggregation components."""

    MEAN = "mean"
    SUM = "sum"
    MIN = "min"
    MAX = "max"


class PeriodType(StrEnum):
    """Temporal period granularities supported by the workflow."""

    HOURLY = "hourly"
    DAILY = "daily"
    MONTHLY = "monthly"
    YEARLY = "yearly"


class FeatureSourceConfig(BaseModel):
    """How to fetch features for spatial aggregation.

    Exactly one backend-specific field is required, depending on
    ``source_type`` (enforced by the model validator below).
    """

    source_type: FeatureSourceType
    geojson_path: str | None = None
    dhis2_level: int | None = None
    dhis2_ids: list[str] | None = None
    dhis2_parent: str | None = None
    feature_id_property: str = "id"

    @model_validator(mode="after")
    def validate_by_source(self) -> "FeatureSourceConfig":
        """Check that the field required by the chosen backend is present."""
        if self.source_type == FeatureSourceType.GEOJSON_FILE:
            if not self.geojson_path:
                raise ValueError("geojson_path is required when source_type='geojson_file'")
        elif self.source_type == FeatureSourceType.DHIS2_LEVEL:
            if self.dhis2_level is None:
                raise ValueError("dhis2_level is required when source_type='dhis2_level'")
        elif self.source_type == FeatureSourceType.DHIS2_IDS:
            if not self.dhis2_ids:
                raise ValueError("dhis2_ids is required when source_type='dhis2_ids'")
        return self


class TemporalAggregationConfig(BaseModel):
    """Temporal rollup config."""

    target_period_type: PeriodType
    method: AggregationMethod = AggregationMethod.SUM


class SpatialAggregationConfig(BaseModel):
    """Spatial aggregation config."""

    method: AggregationMethod = AggregationMethod.MEAN
+ data_set_uid: str | None = None + org_unit_property: str = "id" + stored_by: str | None = None + + +class WorkflowExecuteRequest(BaseModel): + """End-to-end workflow request.""" + + dataset_id: str + start: str + end: str + overwrite: bool = False + country_code: str | None = None + feature_source: FeatureSourceConfig + temporal_aggregation: TemporalAggregationConfig + spatial_aggregation: SpatialAggregationConfig = Field(default_factory=SpatialAggregationConfig) + dhis2: Dhis2DataValueSetConfig + + +class ComponentRun(BaseModel): + """Execution metadata for one workflow component.""" + + component: str + status: str + started_at: str + ended_at: str + duration_ms: int + inputs: dict[str, Any] + outputs: dict[str, Any] | None = None + error: str | None = None + + +class WorkflowExecuteResponse(BaseModel): + """Workflow execution response.""" + + status: str + run_id: str + workflow_id: str + workflow_version: int + dataset_id: str + bbox: list[float] + feature_count: int + value_count: int + output_file: str + run_log_file: str + data_value_set: dict[str, Any] + component_runs: list[ComponentRun] + component_run_details_included: bool = False + component_run_details_available: bool = True + + +class WorkflowCatalogItem(BaseModel): + """Discoverable workflow definition summary.""" + + workflow_id: str + version: int + step_count: int + components: list[str] + + +class WorkflowCatalogResponse(BaseModel): + """List of allowlisted workflow definitions.""" + + workflows: list[WorkflowCatalogItem] + + +class WorkflowRequest(BaseModel): + """Public flat workflow request payload.""" + + workflow_id: str = "dhis2_datavalue_set_v1" + dataset_id: str + start_date: str | None = None + end_date: str | None = None + start_year: int | None = None + end_year: int | None = None + org_unit_level: int | None = None + org_unit_ids: list[str] | None = None + data_element: str + temporal_resolution: PeriodType = PeriodType.MONTHLY + temporal_reducer: AggregationMethod = 
AggregationMethod.SUM + spatial_reducer: AggregationMethod = AggregationMethod.MEAN + overwrite: bool = False + dry_run: bool = True + feature_id_property: str = "id" + stage: str | None = None + flavor: str | None = None + country_code: str | None = None + output_format: str | None = None + reducer: str | None = None + include_component_run_details: bool = False + + @model_validator(mode="after") + def validate_time_window(self) -> "WorkflowRequest": + """Require either date range or year range.""" + has_dates = bool(self.start_date and self.end_date) + has_years = self.start_year is not None and self.end_year is not None + if not has_dates and not has_years: + raise ValueError("Provide either start_date/end_date or start_year/end_year") + if self.org_unit_level is None and not self.org_unit_ids: + raise ValueError("Provide org_unit_level or org_unit_ids") + return self diff --git a/src/eo_api/workflows/services/__init__.py b/src/eo_api/workflows/services/__init__.py new file mode 100644 index 0000000..4e6122f --- /dev/null +++ b/src/eo_api/workflows/services/__init__.py @@ -0,0 +1 @@ +"""Workflow service components.""" diff --git a/src/eo_api/workflows/services/datavalueset.py b/src/eo_api/workflows/services/datavalueset.py new file mode 100644 index 0000000..3768872 --- /dev/null +++ b/src/eo_api/workflows/services/datavalueset.py @@ -0,0 +1,65 @@ +"""DHIS2 DataValueSet builder component.""" + +from __future__ import annotations + +import datetime as dt +import json +from typing import Any + +import numpy as np + +from ...data_manager.services.downloader import DOWNLOAD_DIR +from ..schemas import Dhis2DataValueSetConfig, PeriodType + + +def build_data_value_set( + records: list[dict[str, Any]], + *, + dataset_id: str, + period_type: PeriodType, + config: Dhis2DataValueSetConfig, +) -> tuple[dict[str, Any], str]: + """Build and serialize a DHIS2-compatible DataValueSet JSON payload.""" + data_values: list[dict[str, Any]] = [] + for record in records: + period = 
_format_period(record["time"], period_type) + data_values.append( + { + "dataElement": config.data_element_uid, + "period": period, + "orgUnit": record["org_unit"], + "categoryOptionCombo": config.category_option_combo_uid, + "attributeOptionCombo": config.attribute_option_combo_uid, + "value": str(record["value"]), + } + ) + + payload: dict[str, Any] = {"dataValues": data_values} + if config.data_set_uid: + payload["dataSet"] = config.data_set_uid + if config.stored_by: + payload["storedBy"] = config.stored_by + output_file = _write_data_value_set(payload, dataset_id) + return payload, output_file + + +def _write_data_value_set(payload: dict[str, Any], dataset_id: str) -> str: + """Persist DataValueSet payload and return file path.""" + DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True) + now = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = DOWNLOAD_DIR / f"{dataset_id}_datavalueset_{now}.json" + path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + return str(path) + + +def _format_period(time_value: Any, period_type: PeriodType) -> str: + ts = np.datetime64(time_value) + s = np.datetime_as_string(ts, unit="D") + year, month, day = s.split("-") + if period_type == PeriodType.DAILY: + return f"{year}{month}{day}" + if period_type == PeriodType.MONTHLY: + return f"{year}{month}" + if period_type == PeriodType.YEARLY: + return year + return s.replace("-", "") diff --git a/src/eo_api/workflows/services/definitions.py b/src/eo_api/workflows/services/definitions.py new file mode 100644 index 0000000..9876406 --- /dev/null +++ b/src/eo_api/workflows/services/definitions.py @@ -0,0 +1,132 @@ +"""Declarative workflow definition loading and validation.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Final, Literal + +import yaml +from pydantic import BaseModel, model_validator + +ComponentName = Literal[ + "feature_source", + "download_dataset", + "temporal_aggregation", + "spatial_aggregation", + 
"build_datavalueset", +] + +SUPPORTED_COMPONENTS: Final[set[str]] = set(ComponentName.__args__) # type: ignore[attr-defined] + +COMPONENT_INPUTS: Final[dict[str, set[str]]] = { + "feature_source": set(), + "download_dataset": {"bbox"}, + "temporal_aggregation": {"bbox"}, + "spatial_aggregation": {"bbox", "features"}, + "build_datavalueset": {"records"}, +} + +COMPONENT_OUTPUTS: Final[dict[str, set[str]]] = { + "feature_source": {"features", "bbox"}, + "download_dataset": set(), + "temporal_aggregation": {"temporal_dataset"}, + "spatial_aggregation": {"records"}, + "build_datavalueset": {"data_value_set", "output_file"}, +} + +SCRIPT_DIR = Path(__file__).parent.resolve() +WORKFLOWS_DIR = SCRIPT_DIR.parent.parent.parent.parent / "data" / "workflows" +DEFAULT_WORKFLOW_ID = "dhis2_datavalue_set_v1" + + +class WorkflowStep(BaseModel): + """One component step in a declarative workflow definition.""" + + component: ComponentName + + +class WorkflowDefinition(BaseModel): + """Declarative workflow definition.""" + + workflow_id: str + version: int = 1 + steps: list[WorkflowStep] + + @model_validator(mode="after") + def validate_steps(self) -> "WorkflowDefinition": + """Require terminal DataValueSet step and validate component compatibility.""" + if not self.steps: + raise ValueError("Workflow steps cannot be empty") + if self.steps[-1].component != "build_datavalueset": + raise ValueError("The last workflow step must be 'build_datavalueset'") + available_context: set[str] = set() + for step in self.steps: + required_inputs = COMPONENT_INPUTS[step.component] + missing_inputs = required_inputs - available_context + if missing_inputs: + missing = ", ".join(sorted(missing_inputs)) + raise ValueError(f"Component '{step.component}' is missing required upstream outputs: {missing}") + available_context.update(COMPONENT_OUTPUTS[step.component]) + return self + + +def load_workflow_definition( + workflow_id: str = DEFAULT_WORKFLOW_ID, + *, + path: Path | None = None, +) -> 
WorkflowDefinition: + """Load and validate workflow definition from discovered YAML files.""" + if path is not None: + workflow_file = path + else: + workflow_files = _discover_workflow_files() + workflow_file_or_none = workflow_files.get(workflow_id) + if workflow_file_or_none is None: + known = ", ".join(sorted(workflow_files)) + raise ValueError(f"Unknown workflow_id '{workflow_id}'. Allowed values: {known}") + workflow_file = workflow_file_or_none + + if not workflow_file.exists(): + raise ValueError(f"Workflow definition file not found: {workflow_file}") + with open(workflow_file, encoding="utf-8") as f: + raw = yaml.safe_load(f) + if raw is None: + raise ValueError(f"Workflow definition file is empty: {workflow_file}") + definition = WorkflowDefinition.model_validate(raw) + if path is None and definition.workflow_id != workflow_id: + raise ValueError( + f"workflow_id mismatch: requested '{workflow_id}' but definition declares '{definition.workflow_id}'" + ) + return definition + + +def list_workflow_definitions() -> list[WorkflowDefinition]: + """Load and return all discovered workflow definitions.""" + workflow_files = _discover_workflow_files() + return [load_workflow_definition(workflow_id) for workflow_id in sorted(workflow_files)] + + +def _discover_workflow_files() -> dict[str, Path]: + """Discover and validate workflow IDs from all YAML files in workflows folder.""" + if not WORKFLOWS_DIR.is_dir(): + raise ValueError(f"Workflow directory not found: {WORKFLOWS_DIR}") + + discovered: dict[str, Path] = {} + for workflow_file in sorted(WORKFLOWS_DIR.glob("*.y*ml")): + with open(workflow_file, encoding="utf-8") as f: + raw = yaml.safe_load(f) + if raw is None: + raise ValueError(f"Workflow definition file is empty: {workflow_file}") + if not isinstance(raw, dict): + raise ValueError(f"Workflow definition must be a mapping/object: {workflow_file}") + + workflow_id = raw.get("workflow_id") + if not isinstance(workflow_id, str) or not workflow_id: + raise 
ValueError(f"Missing/invalid workflow_id in: {workflow_file}") + + existing = discovered.get(workflow_id) + if existing is not None: + raise ValueError(f"Duplicate workflow_id '{workflow_id}' in files: {existing.name}, {workflow_file.name}") + discovered[workflow_id] = workflow_file + + return discovered diff --git a/src/eo_api/workflows/services/engine.py b/src/eo_api/workflows/services/engine.py new file mode 100644 index 0000000..9e7da9e --- /dev/null +++ b/src/eo_api/workflows/services/engine.py @@ -0,0 +1,207 @@ +"""Workflow orchestration engine for gridded-data pipelines.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import HTTPException + +from ...components import services as component_services +from ...data_registry.services.datasets import get_dataset +from ..schemas import WorkflowExecuteRequest, WorkflowExecuteResponse +from .definitions import WorkflowDefinition, load_workflow_definition +from .run_logs import persist_run_log +from .runtime import WorkflowRuntime + + +def execute_workflow( + request: WorkflowExecuteRequest, + *, + workflow_id: str = "dhis2_datavalue_set_v1", + include_component_run_details: bool = False, +) -> WorkflowExecuteResponse: + """Execute the feature->download->aggregate->DataValueSet workflow.""" + runtime = WorkflowRuntime() + + dataset = get_dataset(request.dataset_id) + if dataset is None: + raise HTTPException(status_code=404, detail=f"Dataset '{request.dataset_id}' not found") + + context: dict[str, Any] = {} + + try: + try: + workflow = load_workflow_definition(workflow_id) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + _execute_workflow_steps( + workflow=workflow, + runtime=runtime, + request=request, + dataset=dataset, + context=context, + ) + features = _require_context(context, "features") + bbox = _require_context(context, "bbox") + data_value_set = _require_context(context, "data_value_set") + output_file = 
_require_context(context, "output_file") + run_log_file = persist_run_log( + run_id=runtime.run_id, + request=request, + component_runs=runtime.component_runs, + status="completed", + output_file=output_file, + ) + + return WorkflowExecuteResponse( + status="completed", + run_id=runtime.run_id, + workflow_id=workflow.workflow_id, + workflow_version=workflow.version, + dataset_id=request.dataset_id, + bbox=bbox, + feature_count=len(features["features"]), + value_count=len(data_value_set["dataValues"]), + output_file=output_file, + run_log_file=run_log_file, + data_value_set=data_value_set, + component_runs=runtime.component_runs if include_component_run_details else [], + component_run_details_included=include_component_run_details, + component_run_details_available=True, + ) + except HTTPException: + persist_run_log( + run_id=runtime.run_id, + request=request, + component_runs=runtime.component_runs, + status="failed", + error="http_exception", + ) + raise + except Exception as exc: + persist_run_log( + run_id=runtime.run_id, + request=request, + component_runs=runtime.component_runs, + status="failed", + error=str(exc), + ) + last_component = runtime.component_runs[-1].component if runtime.component_runs else "unknown" + if _is_upstream_connectivity_error(exc): + raise HTTPException( + status_code=503, + detail={ + "error": "upstream_unreachable", + "message": "Could not reach upstream data source. 
Check network/proxy and retry.", + "failed_component": last_component, + "run_id": runtime.run_id, + }, + ) from exc + raise HTTPException( + status_code=500, + detail={ + "error": "workflow_execution_failed", + "message": str(exc), + "failed_component": last_component, + "run_id": runtime.run_id, + }, + ) from exc + + +def _is_upstream_connectivity_error(exc: Exception) -> bool: + message = str(exc).lower() + patterns = ( + "could not connect to server", + "failed to connect", + "connection refused", + "name or service not known", + "temporary failure in name resolution", + "timed out", + "curl error", + ) + return any(pattern in message for pattern in patterns) + + +def _execute_workflow_steps( + *, + workflow: WorkflowDefinition, + runtime: WorkflowRuntime, + request: WorkflowExecuteRequest, + dataset: dict[str, Any], + context: dict[str, Any], +) -> None: + """Execute workflow components using declarative YAML step order.""" + for step in workflow.steps: + if step.component == "feature_source": + features, bbox = runtime.run( + "feature_source", + component_services.feature_source_component, + config=request.feature_source, + ) + context["features"] = features + context["bbox"] = bbox + continue + + if step.component == "download_dataset": + runtime.run( + "download_dataset", + component_services.download_dataset_component, + dataset=dataset, + start=request.start, + end=request.end, + overwrite=request.overwrite, + country_code=request.country_code, + bbox=_require_context(context, "bbox"), + ) + continue + + if step.component == "temporal_aggregation": + temporal_ds = runtime.run( + "temporal_aggregation", + component_services.temporal_aggregation_component, + dataset=dataset, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + target_period_type=request.temporal_aggregation.target_period_type, + method=request.temporal_aggregation.method, + ) + context["temporal_dataset"] = temporal_ds + continue + + if step.component == 
"spatial_aggregation": + records = runtime.run( + "spatial_aggregation", + component_services.spatial_aggregation_component, + dataset=dataset, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + features=_require_context(context, "features"), + method=request.spatial_aggregation.method, + feature_id_property=request.dhis2.org_unit_property, + ) + context["records"] = records + continue + + if step.component == "build_datavalueset": + data_value_set, output_file = runtime.run( + "build_datavalueset", + component_services.build_datavalueset_component, + records=_require_context(context, "records"), + dataset_id=request.dataset_id, + period_type=request.temporal_aggregation.target_period_type, + dhis2=request.dhis2, + ) + context["data_value_set"] = data_value_set + context["output_file"] = output_file + continue + + raise RuntimeError(f"Unsupported workflow component '{step.component}'") + + +def _require_context(context: dict[str, Any], key: str) -> Any: + """Return required context value or raise a clear orchestration error.""" + if key not in context: + raise RuntimeError(f"Workflow definition missing prerequisite for '{key}'") + return context[key] diff --git a/src/eo_api/workflows/services/features.py b/src/eo_api/workflows/services/features.py new file mode 100644 index 0000000..95d8cc4 --- /dev/null +++ b/src/eo_api/workflows/services/features.py @@ -0,0 +1,69 @@ +"""Feature source component for workflow execution.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import geopandas as gpd + +from ...shared.dhis2_adapter import create_client, get_org_unit_geojson, get_org_units_geojson +from ..schemas import FeatureSourceConfig, FeatureSourceType + + +def resolve_features(config: FeatureSourceConfig) -> tuple[dict[str, Any], list[float]]: + """Resolve features from a source and return FeatureCollection + bbox.""" + if config.source_type == 
FeatureSourceType.GEOJSON_FILE: + collection = _read_geojson_file(config.geojson_path or "") + elif config.source_type == FeatureSourceType.DHIS2_LEVEL: + client = create_client() + collection = get_org_units_geojson(client, level=config.dhis2_level, parent=config.dhis2_parent) + else: + client = create_client() + collection = _collection_from_dhis2_ids(client, config.dhis2_ids or []) + + collection = _normalize_feature_collection(collection) + bbox = _bbox_from_feature_collection(collection) + return collection, bbox + + +def feature_id(feature: dict[str, Any], key: str) -> str: + """Get feature identifier from properties, feature id, or UID fallbacks.""" + properties = feature.get("properties", {}) + value = properties.get(key) or feature.get("id") or properties.get("id") or properties.get("uid") + if value is None: + raise ValueError(f"Unable to find feature identifier using key '{key}'") + return str(value) + + +def _read_geojson_file(path: str) -> dict[str, Any]: + raw = json.loads(Path(path).read_text(encoding="utf-8")) + return _normalize_feature_collection(raw) + + +def _collection_from_dhis2_ids(client: Any, ou_ids: list[str]) -> dict[str, Any]: + features: list[dict[str, Any]] = [] + for uid in ou_ids: + unit_geojson = get_org_unit_geojson(client, uid) + normalized = _normalize_feature_collection(unit_geojson) + features.extend(normalized["features"]) + return {"type": "FeatureCollection", "features": features} + + +def _normalize_feature_collection(raw: dict[str, Any]) -> dict[str, Any]: + raw_type = raw.get("type") + if raw_type == "FeatureCollection": + return raw + if raw_type == "Feature": + return {"type": "FeatureCollection", "features": [raw]} + if "features" in raw and isinstance(raw["features"], list): + return {"type": "FeatureCollection", "features": raw["features"]} + raise ValueError("Input is not a valid GeoJSON feature or feature collection") + + +def _bbox_from_feature_collection(collection: dict[str, Any]) -> list[float]: + if not 
collection.get("features"): + raise ValueError("Feature collection is empty") + bounds = gpd.read_file(json.dumps(collection)).total_bounds + return [float(v) for v in bounds] diff --git a/src/eo_api/workflows/services/preflight.py b/src/eo_api/workflows/services/preflight.py new file mode 100644 index 0000000..9ff6c3f --- /dev/null +++ b/src/eo_api/workflows/services/preflight.py @@ -0,0 +1,23 @@ +"""Preflight checks for external data source connectivity.""" + +from __future__ import annotations + +import socket +from urllib.parse import urlparse + + +def check_upstream_connectivity(dataset: dict[str, object], timeout_seconds: float = 5.0) -> None: + """Fail fast if a dataset source host is not reachable.""" + source_url = dataset.get("source_url") + if not isinstance(source_url, str) or not source_url: + return + + parsed = urlparse(source_url) + hostname = parsed.hostname + if not hostname: + return + port = parsed.port or (443 if parsed.scheme == "https" else 80) + + # Fail quickly on DNS/TCP connectivity issues instead of waiting for long GDAL timeouts. 
+ with socket.create_connection((hostname, port), timeout=timeout_seconds): + pass diff --git a/src/eo_api/workflows/services/run_logs.py b/src/eo_api/workflows/services/run_logs.py new file mode 100644 index 0000000..ea4d375 --- /dev/null +++ b/src/eo_api/workflows/services/run_logs.py @@ -0,0 +1,37 @@ +"""Run-log persistence for workflow executions.""" + +from __future__ import annotations + +import datetime as dt +import json +from typing import Any + +from ...data_manager.services.downloader import DOWNLOAD_DIR +from ..schemas import ComponentRun, WorkflowExecuteRequest + + +def persist_run_log( + *, + run_id: str, + request: WorkflowExecuteRequest, + component_runs: list[ComponentRun], + status: str, + output_file: str | None = None, + error: str | None = None, +) -> str: + """Write workflow run metadata to disk and return file path.""" + logs_dir = DOWNLOAD_DIR / "workflow_runs" + logs_dir.mkdir(parents=True, exist_ok=True) + timestamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ") + path = logs_dir / f"{timestamp}_{run_id}.json" + + payload: dict[str, Any] = { + "run_id": run_id, + "status": status, + "request": request.model_dump(mode="json"), + "component_runs": [run.model_dump(mode="json") for run in component_runs], + "output_file": output_file, + "error": error, + } + path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + return str(path) diff --git a/src/eo_api/workflows/services/runtime.py b/src/eo_api/workflows/services/runtime.py new file mode 100644 index 0000000..905ba54 --- /dev/null +++ b/src/eo_api/workflows/services/runtime.py @@ -0,0 +1,89 @@ +"""Component runtime wrapper for workflow housekeeping metadata.""" + +from __future__ import annotations + +import datetime as dt +import time +import uuid +from collections.abc import Callable +from typing import Any + +from ..schemas import ComponentRun + + +class WorkflowRuntime: + """Capture execution metadata for component orchestration.""" + + def __init__(self) -> 
None: + self.run_id = str(uuid.uuid4()) + self.component_runs: list[ComponentRun] = [] + + def run(self, component: str, fn: Callable[..., Any], **kwargs: Any) -> Any: + """Execute one component and record start/end/input/output metadata.""" + started = dt.datetime.now(dt.timezone.utc) + started_perf = time.perf_counter() + + try: + result = fn(**kwargs) + ended = dt.datetime.now(dt.timezone.utc) + self.component_runs.append( + ComponentRun( + component=component, + status="completed", + started_at=started.isoformat(), + ended_at=ended.isoformat(), + duration_ms=int((time.perf_counter() - started_perf) * 1000), + inputs=_to_json_summary(kwargs), + outputs={"result": _to_json_summary(result)}, + ) + ) + return result + except Exception as exc: + ended = dt.datetime.now(dt.timezone.utc) + self.component_runs.append( + ComponentRun( + component=component, + status="failed", + started_at=started.isoformat(), + ended_at=ended.isoformat(), + duration_ms=int((time.perf_counter() - started_perf) * 1000), + inputs=_to_json_summary(kwargs), + outputs=None, + error=str(exc), + ) + ) + raise + + +def _to_json_summary(value: Any, *, depth: int = 0, max_depth: int = 2) -> Any: + """Convert arbitrary values into a compact JSON-safe summary.""" + if depth >= max_depth: + return _fallback_summary(value) + + if value is None or isinstance(value, (str, int, float, bool)): + return value + + if isinstance(value, list): + return [_to_json_summary(v, depth=depth + 1, max_depth=max_depth) for v in value[:20]] + + if isinstance(value, tuple): + return [_to_json_summary(v, depth=depth + 1, max_depth=max_depth) for v in value[:20]] + + if isinstance(value, dict): + out: dict[str, Any] = {} + for i, (k, v) in enumerate(value.items()): + if i >= 30: + out["..."] = "truncated" + break + out[str(k)] = _to_json_summary(v, depth=depth + 1, max_depth=max_depth) + return out + + return _fallback_summary(value) + + +def _fallback_summary(value: Any) -> str: + if hasattr(value, "shape"): + return 
f"{type(value).__name__}(shape={getattr(value, 'shape')})" + if hasattr(value, "sizes"): + return f"{type(value).__name__}(sizes={getattr(value, 'sizes')})" + return type(value).__name__ diff --git a/src/eo_api/workflows/services/simple_mapper.py b/src/eo_api/workflows/services/simple_mapper.py new file mode 100644 index 0000000..74b8758 --- /dev/null +++ b/src/eo_api/workflows/services/simple_mapper.py @@ -0,0 +1,86 @@ +"""Mapper from simplified workflow inputs to internal workflow request.""" + +from __future__ import annotations + +from fastapi import HTTPException + +from ...data_registry.services.datasets import get_dataset +from ..schemas import ( + AggregationMethod, + Dhis2DataValueSetConfig, + FeatureSourceConfig, + FeatureSourceType, + SpatialAggregationConfig, + TemporalAggregationConfig, + WorkflowExecuteRequest, + WorkflowRequest, +) + +_IGNORED_FIELDS = ["dry_run", "stage", "flavor", "output_format"] + + +def normalize_simple_request(payload: WorkflowRequest) -> tuple[WorkflowExecuteRequest, list[str]]: + """Translate public workflow request format to internal workflow request.""" + inputs = payload + dataset_id = inputs.dataset_id + dataset = get_dataset(dataset_id) + + period_type = str(dataset.get("period_type", "")).lower() if dataset else "" + + if inputs.start_date and inputs.end_date: + if period_type == "yearly": + start = inputs.start_date[:4] + end = inputs.end_date[:4] + elif period_type in {"hourly", "daily", "monthly"}: + # dhis2eo downloaders expect month windows for these dataset types. 
+ start = inputs.start_date[:7] + end = inputs.end_date[:7] + else: + start = inputs.start_date + end = inputs.end_date + elif inputs.start_year is not None and inputs.end_year is not None: + if period_type == "yearly": + start = str(inputs.start_year) + end = str(inputs.end_year) + else: + start = f"{inputs.start_year}-01-01" + end = f"{inputs.end_year}-12-31" + else: + raise HTTPException(status_code=422, detail="Provide either start_date/end_date or start_year/end_year") + + if inputs.org_unit_level is not None: + feature_source = FeatureSourceConfig( + source_type=FeatureSourceType.DHIS2_LEVEL, + dhis2_level=inputs.org_unit_level, + feature_id_property=inputs.feature_id_property, + ) + elif inputs.org_unit_ids: + feature_source = FeatureSourceConfig( + source_type=FeatureSourceType.DHIS2_IDS, + dhis2_ids=inputs.org_unit_ids, + feature_id_property=inputs.feature_id_property, + ) + else: + raise HTTPException(status_code=422, detail="Provide org_unit_level or org_unit_ids") + + reducer_alias = AggregationMethod(inputs.reducer.lower()) if inputs.reducer else None + spatial_method = reducer_alias or inputs.spatial_reducer + temporal_method = reducer_alias or inputs.temporal_reducer + + normalized = WorkflowExecuteRequest( + dataset_id=dataset_id, + start=start, + end=end, + overwrite=inputs.overwrite, + country_code=inputs.country_code, + feature_source=feature_source, + temporal_aggregation=TemporalAggregationConfig( + target_period_type=inputs.temporal_resolution, + method=temporal_method, + ), + spatial_aggregation=SpatialAggregationConfig(method=spatial_method), + dhis2=Dhis2DataValueSetConfig(data_element_uid=inputs.data_element), + ) + + warnings = [f"Input field '{field}' is currently accepted but not used in execution" for field in _IGNORED_FIELDS] + return normalized, warnings diff --git a/src/eo_api/workflows/services/spatial.py b/src/eo_api/workflows/services/spatial.py new file mode 100644 index 0000000..668141e --- /dev/null +++ 
b/src/eo_api/workflows/services/spatial.py @@ -0,0 +1,56 @@ +"""Spatial aggregation component for gridded datasets.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import xarray as xr +from shapely import contains_xy +from shapely.geometry import shape + +from ...data_manager.services.utils import get_lon_lat_dims, get_time_dim +from .features import feature_id + + +def aggregate_to_features( + ds: xr.Dataset, + *, + variable: str, + features: dict[str, Any], + method: str, + feature_id_property: str, +) -> list[dict[str, Any]]: + """Aggregate one gridded variable into per-feature time series.""" + da = ds[variable] + time_dim = get_time_dim(da) + lon_dim, lat_dim = get_lon_lat_dims(da) + lon_values = da[lon_dim].values + lat_values = da[lat_dim].values + lon_grid, lat_grid = np.meshgrid(lon_values, lat_values) + + output: list[dict[str, Any]] = [] + for feature in features.get("features", []): + geom = shape(feature["geometry"]) + mask = contains_xy(geom, lon_grid, lat_grid) + if not np.any(mask): + continue + + mask_da = xr.DataArray( + mask, + dims=(lat_dim, lon_dim), + coords={lat_dim: da[lat_dim], lon_dim: da[lon_dim]}, + ) + reduced = getattr(da.where(mask_da), method)(dim=[lat_dim, lon_dim], skipna=True) + org_unit = feature_id(feature, feature_id_property) + for t, value in zip(reduced[time_dim].values, reduced.values, strict=True): + if np.isnan(value): + continue + output.append( + { + "org_unit": org_unit, + "time": t, + "value": float(value), + } + ) + return output diff --git a/src/eo_api/workflows/services/temporal.py b/src/eo_api/workflows/services/temporal.py new file mode 100644 index 0000000..e244f85 --- /dev/null +++ b/src/eo_api/workflows/services/temporal.py @@ -0,0 +1,25 @@ +"""Temporal aggregation component.""" + +from __future__ import annotations + +from typing import cast + +import xarray as xr + +from ...data_manager.services.utils import get_time_dim +from ..schemas import AggregationMethod, 
PeriodType + +_PERIOD_TO_FREQ: dict[PeriodType, str] = { + PeriodType.HOURLY: "1h", + PeriodType.DAILY: "1D", + PeriodType.MONTHLY: "MS", + PeriodType.YEARLY: "YS", +} + + +def aggregate_temporal(ds: xr.Dataset, *, period_type: PeriodType, method: AggregationMethod) -> xr.Dataset: + """Resample a dataset over the time dimension to the target period.""" + time_dim = get_time_dim(ds) + freq = _PERIOD_TO_FREQ[period_type] + resampled = ds.resample({time_dim: freq}) + return cast(xr.Dataset, getattr(resampled, method.value)(keep_attrs=True)) diff --git a/tests/test_root.py b/tests/test_root.py index 5353344..f90d3c6 100644 --- a/tests/test_root.py +++ b/tests/test_root.py @@ -1,6 +1,6 @@ from fastapi.testclient import TestClient -from eo_api.schemas import HealthStatus, RootResponse +from eo_api.system.schemas import HealthStatus, RootResponse def test_root_returns_200(client: TestClient) -> None: diff --git a/tests/test_workflows.py b/tests/test_workflows.py new file mode 100644 index 0000000..da16352 --- /dev/null +++ b/tests/test_workflows.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +from typing import Any + +import pytest +import xarray as xr +from fastapi import HTTPException +from fastapi.routing import APIRoute +from fastapi.testclient import TestClient + +from eo_api.main import app +from eo_api.workflows.schemas import WorkflowExecuteRequest, WorkflowExecuteResponse, WorkflowRequest +from eo_api.workflows.services import engine +from eo_api.workflows.services.definitions import WorkflowDefinition, load_workflow_definition +from eo_api.workflows.services.simple_mapper import normalize_simple_request + + +def _valid_public_payload() -> dict[str, Any]: + return { + "workflow_id": "dhis2_datavalue_set_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 3, + "data_element": "abc123def45", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": 
"mean", + "dry_run": True, + "include_component_run_details": False, + } + + +def test_workflow_endpoint_exists_once() -> None: + workflow_routes = { + route.path + for route in app.routes + if isinstance(route, APIRoute) and route.path.startswith("/workflows") and "POST" in route.methods + } + assert workflow_routes == {"/workflows/dhis2-datavalue-set"} + + +def test_workflow_catalog_endpoint_returns_allowlisted_workflow(client: TestClient) -> None: + response = client.get("/workflows") + assert response.status_code == 200 + body = response.json() + assert "workflows" in body + assert len(body["workflows"]) >= 2 + by_id = {item["workflow_id"]: item for item in body["workflows"]} + + default = by_id["dhis2_datavalue_set_v1"] + assert default["version"] == 1 + assert default["step_count"] == 5 + assert default["components"] == [ + "feature_source", + "download_dataset", + "temporal_aggregation", + "spatial_aggregation", + "build_datavalueset", + ] + + fast = by_id["dhis2_datavalue_set_without_temporal_aggregation_v1"] + assert fast["version"] == 1 + assert fast["step_count"] == 4 + assert fast["components"] == [ + "feature_source", + "download_dataset", + "spatial_aggregation", + "build_datavalueset", + ] + + +def test_components_catalog_endpoint_returns_five_components(client: TestClient) -> None: + response = client.get("/components") + assert response.status_code == 200 + items = response.json()["components"] + names = {item["name"] for item in items} + assert names == { + "feature_source", + "download_dataset", + "temporal_aggregation", + "spatial_aggregation", + "build_datavalueset", + } + + +def test_workflow_endpoint_returns_response_shape(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + stub = WorkflowExecuteResponse( + status="completed", + run_id="run-123", + workflow_id="dhis2_datavalue_set_v1", + workflow_version=1, + dataset_id="chirps3_precipitation_daily", + bbox=[-13.3, 6.9, -10.1, 10.0], + feature_count=2, + value_count=4, + 
output_file="/tmp/data/chirps3_datavalueset.json", + run_log_file="/tmp/data/workflow_runs/run-123.json", + data_value_set={ + "dataValues": [ + { + "dataElement": "abc123def45", + "period": "202401", + "orgUnit": "OU_1", + "categoryOptionCombo": "HllvX50cXC0", + "attributeOptionCombo": "HllvX50cXC0", + "value": "12.3", + } + ] + }, + component_runs=[], + ) + monkeypatch.setattr( + "eo_api.workflows.routes.execute_workflow", + lambda payload, workflow_id="dhis2_datavalue_set_v1", include_component_run_details=False: stub, + ) + + response = client.post("/workflows/dhis2-datavalue-set", json=_valid_public_payload()) + assert response.status_code == 200 + body = response.json() + assert body["status"] == "completed" + assert body["run_id"] == "run-123" + assert body["workflow_id"] == "dhis2_datavalue_set_v1" + assert body["workflow_version"] == 1 + assert body["run_log_file"].endswith(".json") + assert "dataValues" in body["data_value_set"] + assert body["component_run_details_included"] is False + assert body["component_run_details_available"] is True + + +def test_workflow_endpoint_validates_required_fields(client: TestClient) -> None: + payload = _valid_public_payload() + payload.pop("org_unit_level") + + response = client.post("/workflows/dhis2-datavalue-set", json=payload) + assert response.status_code == 422 + + +def test_workflow_endpoint_accepts_simplified_payload(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + normalized = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3, "feature_id_property": "id"}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + stub = WorkflowExecuteResponse( + status="completed", + run_id="run-123", + 
workflow_id="dhis2_datavalue_set_v1", + workflow_version=1, + dataset_id="chirps3_precipitation_daily", + bbox=[-13.3, 6.9, -10.1, 10.0], + feature_count=2, + value_count=4, + output_file="/tmp/data/chirps3_datavalueset.json", + run_log_file="/tmp/data/workflow_runs/run-123.json", + data_value_set={"dataValues": []}, + component_runs=[], + ) + monkeypatch.setattr("eo_api.workflows.routes.normalize_simple_request", lambda payload: (normalized, [])) + monkeypatch.setattr( + "eo_api.workflows.routes.execute_workflow", + lambda payload, workflow_id="dhis2_datavalue_set_v1", include_component_run_details=False: stub, + ) + + response = client.post("/workflows/dhis2-datavalue-set", json=_valid_public_payload()) + assert response.status_code == 200 + assert response.json()["status"] == "completed" + + +def test_engine_orchestrates_components(monkeypatch: pytest.MonkeyPatch) -> None: + request = { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "country_code": "SLE", + "feature_source": { + "source_type": "geojson_file", + "geojson_path": "tests/data/sierra_leone_districts.geojson", + "feature_id_property": "id", + }, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + + dataset = {"id": "chirps3_precipitation_daily", "variable": "precip"} + ds = xr.Dataset( + {"precip": (("time", "lat", "lon"), [[[1.0]]])}, + coords={"time": ["2024-01-01"], "lat": [0], "lon": [0]}, + ) + + monkeypatch.setattr(engine, "get_dataset", lambda dataset_id: dataset) + + called: dict[str, Any] = {"downloaded": False} + + def _download_dataset_component(**kwargs: Any) -> None: + called["downloaded"] = True + assert kwargs["bbox"] == [0.0, 0.0, 1.0, 1.0] + assert kwargs["country_code"] == "SLE" + + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", 
"features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0.0, 0.0, 1.0, 1.0], + ), + ) + monkeypatch.setattr(engine.component_services, "download_dataset_component", _download_dataset_component) + monkeypatch.setattr(engine.component_services, "temporal_aggregation_component", lambda **kwargs: ds) + monkeypatch.setattr( + engine.component_services, + "spatial_aggregation_component", + lambda **kwargs: [{"org_unit": "OU_1", "time": "2024-01-01", "value": 10.0}], + ) + monkeypatch.setattr( + engine.component_services, + "build_datavalueset_component", + lambda **kwargs: ({"dataValues": [{"value": "10.0"}]}, "/tmp/data/out.json"), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + response = engine.execute_workflow( + engine.WorkflowExecuteRequest.model_validate(request), + include_component_run_details=True, + ) + assert response.status == "completed" + assert response.run_id + assert response.value_count == 1 + assert response.run_log_file.endswith(".json") + assert len(response.component_runs) == 5 + assert [c.component for c in response.component_runs] == [ + "feature_source", + "download_dataset", + "temporal_aggregation", + "spatial_aggregation", + "build_datavalueset", + ] + assert response.component_run_details_included is True + assert response.component_run_details_available is True + assert called["downloaded"] is True + + +def test_engine_hides_component_details_by_default(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + ds = xr.Dataset( + {"precip": (("time", "lat", "lon"), [[[1.0]]])}, + coords={"time": 
["2024-01-01"], "lat": [0], "lon": [0]}, + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr(engine.component_services, "download_dataset_component", lambda **kwargs: None) + monkeypatch.setattr(engine.component_services, "temporal_aggregation_component", lambda **kwargs: ds) + monkeypatch.setattr( + engine.component_services, + "spatial_aggregation_component", + lambda **kwargs: [{"org_unit": "OU_1", "time": "2024-01-01", "value": 10.0}], + ) + monkeypatch.setattr( + engine.component_services, + "build_datavalueset_component", + lambda **kwargs: ({"dataValues": [{"value": "10.0"}]}, "/tmp/data/out.json"), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + response = engine.execute_workflow(request) + assert response.component_runs == [] + assert response.component_run_details_included is False + assert response.component_run_details_available is True + + +def test_engine_returns_503_when_upstream_unreachable(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr(engine, "get_dataset", lambda dataset_id: {"id": "chirps3_precipitation_daily"}) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": 
[{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr( + engine.component_services, + "download_dataset_component", + lambda **kwargs: (_ for _ in ()).throw(RuntimeError("Failed to connect to server")), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + with pytest.raises(HTTPException) as exc_info: + engine.execute_workflow(request) + + assert exc_info.value.status_code == 503 + + +def test_mapper_uses_year_format_for_yearly_dataset() -> None: + normalized, _warnings = normalize_simple_request( + WorkflowRequest.model_validate( + { + "dataset_id": "worldpop_population_yearly", + "country_code": "SLE", + "start_year": 2015, + "end_year": 2026, + "org_unit_level": 2, + "data_element": "DE_UID", + "temporal_resolution": "yearly", + } + ) + ) + assert normalized.start == "2015" + assert normalized.end == "2026" + + +def test_mapper_reducer_alias_overrides_spatial_and_temporal_reducers() -> None: + normalized, _warnings = normalize_simple_request( + WorkflowRequest.model_validate( + { + "dataset_id": "worldpop_population_yearly", + "country_code": "SLE", + "start_year": 2015, + "end_year": 2026, + "org_unit_level": 2, + "data_element": "DE_UID", + "temporal_resolution": "yearly", + "reducer": "sum", + } + ) + ) + assert normalized.spatial_aggregation.method.value == "sum" + assert normalized.temporal_aggregation.method.value == "sum" + + +def test_mapper_uses_month_format_for_chirps_date_window() -> None: + normalized, _warnings = normalize_simple_request( + WorkflowRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-05-31", + "org_unit_level": 2, + "data_element": "DE_UID", + } + ) + ) + assert normalized.start == "2024-01" + assert normalized.end == "2024-05" + + +def test_default_workflow_definition_has_expected_steps() -> None: + workflow = load_workflow_definition() + assert 
workflow.workflow_id == "dhis2_datavalue_set_v1" + assert workflow.version == 1 + assert [step.component for step in workflow.steps] == [ + "feature_source", + "download_dataset", + "temporal_aggregation", + "spatial_aggregation", + "build_datavalueset", + ] + + +def test_engine_follows_declarative_workflow_order(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + ds = xr.Dataset( + {"precip": (("time", "lat", "lon"), [[[1.0]]])}, + coords={"time": ["2024-01-01"], "lat": [0], "lon": [0]}, + ) + monkeypatch.setattr( + engine, + "load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 1, + "steps": [ + {"component": "feature_source"}, + {"component": "download_dataset"}, + {"component": "spatial_aggregation"}, + {"component": "build_datavalueset"}, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr(engine.component_services, "download_dataset_component", lambda **kwargs: None) + monkeypatch.setattr(engine.component_services, "temporal_aggregation_component", lambda **kwargs: ds) + monkeypatch.setattr( + engine.component_services, + "spatial_aggregation_component", + lambda **kwargs: [{"org_unit": "OU_1", "time": "2024-01-01", "value": 10.0}], + ) + monkeypatch.setattr( + 
engine.component_services, + "build_datavalueset_component", + lambda **kwargs: ({"dataValues": [{"value": "10.0"}]}, "/tmp/data/out.json"), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + response = engine.execute_workflow(request, include_component_run_details=True) + assert response.workflow_id == "dhis2_datavalue_set_v1" + assert response.workflow_version == 1 + assert [c.component for c in response.component_runs] == [ + "feature_source", + "download_dataset", + "spatial_aggregation", + "build_datavalueset", + ] + + +def test_engine_rejects_unknown_workflow_id(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + + with pytest.raises(HTTPException) as exc_info: + engine.execute_workflow(request, workflow_id="not_allowlisted") + + assert exc_info.value.status_code == 422 From feed7959c77df53fa7874243979980d285138012 Mon Sep 17 00:00:00 2001 From: Abyot Asalefew Gizaw Date: Wed, 11 Mar 2026 23:34:19 +0100 Subject: [PATCH 2/5] Document workflow architecture and manual E2E testing --- README.md | 18 ++ docs/workflow-orchestration.md | 517 +++++++++++++++++++++++++++++++++ 2 files changed, 535 insertions(+) create mode 100644 docs/workflow-orchestration.md diff --git a/README.md b/README.md index ce64dec..224f48b 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,24 @@ Docs: http://127.0.0.1:8000/docs +Workflow (single payload contract): + +`POST /workflows/dhis2-datavalue-set` + +```json 
+{ + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-05-31", + "org_unit_level": 3, + "data_element": "DE_UID", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "dry_run": true +} +``` + OGC API http://127.0.0.1:8000/ogcapi diff --git a/docs/workflow-orchestration.md b/docs/workflow-orchestration.md new file mode 100644 index 0000000..22315c8 --- /dev/null +++ b/docs/workflow-orchestration.md @@ -0,0 +1,517 @@ +# Workflow Orchestration Design (Single Endpoint, Componentized Runtime) + +## Purpose + +This document describes the implemented approach for generating a DHIS2 DataValueSet from gridded EO datasets through one workflow endpoint and reusable components. + +It documents: + +1. What has been achieved. +2. The architecture and execution flow. +3. Public API contract and normalization rules. +4. Runtime metadata, observability, and error handling. +5. Current componentization strategy and extension path. + +--- + +## What Is Implemented + +The current implementation provides: + +1. One canonical workflow execution endpoint: + - `POST /workflows/dhis2-datavalue-set` +2. One public flat request payload contract (`WorkflowRequest`). +3. Internal normalization into a canonical execution model (`WorkflowExecuteRequest`). +4. A fixed generic orchestration chain with exactly 5 components: + - `feature_source` + - `download_dataset` + - `temporal_aggregation` + - `spatial_aggregation` + - `build_datavalueset` +5. Per-component runtime instrumentation (`WorkflowRuntime`) with timing, status, and summarized inputs/outputs. +6. Run-log persistence for both success and failure. +7. Structured error responses, including upstream connectivity failures. +8. Optional inclusion of detailed component run traces in API responses. +9. Discoverable standalone component endpoints under `/components` for direct execution and future orchestrator integration. +10. 
Declarative workflow assembly via YAML (`data/workflows/dhis2_datavalue_set.yaml`) executed by the workflow engine. + +--- + +## Final API Surface + +### Primary Workflow Endpoint + +- `POST /workflows/dhis2-datavalue-set` + +### Workflow Discovery Endpoint + +- `GET /workflows` (discovered workflow catalog from `data/workflows/*.yaml` with `workflow_id`, `version`, and component chain) + +### Component Discovery/Execution Endpoints + +- `GET /components` +- `POST /components/feature-source` +- `POST /components/download-dataset` +- `POST /components/temporal-aggregation` +- `POST /components/spatial-aggregation` +- `POST /components/build-datavalue-set` + +`/components/*` endpoints are for reusable task-level execution. The workflow endpoint remains the single end-to-end API for generating DHIS2 DataValueSet output. + +--- + +## Public Workflow Request Contract + +The workflow endpoint accepts one flat payload shape: + +```json +{ + "workflow_id": "dhis2_datavalue_set_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-05-31", + "org_unit_level": 2, + "data_element": "DE_UID", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "include_component_run_details": false +} +``` + +Important fields: + +1. `dataset_id` (required) +2. `workflow_id` (optional, default `dhis2_datavalue_set_v1`, must exist in discovered workflow YAMLs) +3. Time window (required as one of): + - `start_date` + `end_date`, or + - `start_year` + `end_year` +4. Spatial scope (required as one of): + - `org_unit_level`, or + - `org_unit_ids` +5. `data_element` (required) +6. `include_component_run_details` (optional, default `false`) + +Notes: + +1. `feature_id_property` defaults to `"id"` and controls which feature property maps to DHIS2 org unit ID in spatial aggregation/DataValueSet construction. +2. 
`country_code` is accepted in request and passed to dataset downloaders (instead of forcing `.env` only). +3. `reducer` is accepted as an alias and mapped to both temporal and spatial reducer when provided. + +--- + +## Normalization and Mapping Approach + +File: `src/eo_api/workflows/services/simple_mapper.py` + +Public flat payload is normalized to internal `WorkflowExecuteRequest` with component-ready nested configs: + +1. `feature_source` config: + - `org_unit_level` -> `source_type=dhis2_level` + - `org_unit_ids` -> `source_type=dhis2_ids` +2. `temporal_aggregation` config: + - `target_period_type` from `temporal_resolution` + - `method` from `temporal_reducer` (or `reducer` alias) +3. `spatial_aggregation` config: + - `method` from `spatial_reducer` (or `reducer` alias) +4. `dhis2` config: + - `data_element_uid` from `data_element` + +Time normalization depends on dataset registry metadata (`period_type`): + +1. Yearly datasets: + - normalize to `YYYY` +2. Hourly/Daily/Monthly datasets: + - normalize to month windows (`YYYY-MM`) for downloader compatibility +3. Fallback: + - pass date strings as provided + +This mapping keeps the public contract simple while preserving an extensible internal orchestration model. + +--- + +## Architecture + +### API Routing Layer + +Files: + +1. `src/eo_api/workflows/routes.py` +2. `src/eo_api/components/routes.py` +3. `src/eo_api/main.py` + +Responsibilities: + +1. Expose one workflow endpoint and reusable component endpoints. +2. Keep payload and response models explicit with Pydantic. +3. Delegate execution logic to service layers. + +### Workflow Engine Layer + +File: `src/eo_api/workflows/services/engine.py` + +Responsibilities: + +1. Validate dataset existence via registry. +2. Execute the 5 components in fixed order. +3. Collect runtime telemetry for each component. +4. Persist run logs on both success and error paths. +5. Return workflow result with optional component-run detail inclusion. 
+ +### Workflow Definition Layer + +Files: + +1. `src/eo_api/workflows/services/definitions.py` +2. `data/workflows/dhis2_datavalue_set.yaml` + +Responsibilities: + +1. Discover, load, and validate declarative workflow definitions from `data/workflows/*.yaml`. +2. Enforce supported component names. +3. Enforce terminal `build_datavalueset` step for this end-to-end workflow. +4. Enforce output-to-input compatibility across the full accumulated context (not just adjacent steps). +5. Drive runtime execution order from YAML instead of hardcoded sequence. + +### Reusable Component Service Layer + +File: `src/eo_api/components/services.py` + +Responsibilities: + +1. Provide discoverable component catalog metadata. +2. Implement component functions used by: + - workflow engine, and + - `/components/*` task endpoints. +3. Reuse existing EO API capabilities (`downloader`, `accessor`, temporal/spatial aggregators, DataValueSet builder). + +--- + +## Layering Rationale + +The repository uses three layers with different responsibilities: + +1. `data_xxx` services (`data_manager`, `data_accessor`, `data_registry`) + - Core domain capabilities (download, load/subset, dataset metadata). + - No workflow-specific orchestration state required. +2. `components/` + - Thin reusable wrappers around core capabilities. + - Standardized component contracts for discovery (`GET /components`) and direct task execution. + - Runtime-friendly boundaries for future orchestrators (Prefect/Airflow). +3. `workflows/` + - End-to-end orchestration, request normalization, workflow selection, runtime tracing, and run-log persistence. + - Declarative assembly from `data/workflows/*.yaml`. + +Example: + +1. `download_dataset` workflow/component step delegates actual download work to `src/eo_api/data_manager/services/downloader.py`. +2. The wrapper adds orchestration-level concerns (preflight, context wiring, component runtime metadata) without duplicating downloader logic. 
+ +This separation keeps core data services reusable and prevents workflow-specific concerns from leaking into the low-level data modules. + +--- + +## Component Chain (Exact Runtime Order) + +The workflow engine executes these components, no more and no less: + +1. `feature_source` +2. `download_dataset` +3. `temporal_aggregation` +4. `spatial_aggregation` +5. `build_datavalueset` + +Details: + +1. `feature_source` + - Resolves features from DHIS2 org unit level/ids or GeoJSON source config. + - Returns `FeatureCollection` and `bbox`. +2. `download_dataset` + - Runs connectivity preflight and downloads source data using `data_manager/services/downloader.py`. + - Supports request-supplied `country_code` where needed (e.g., WorldPop). +3. `temporal_aggregation` + - Loads/subsets data and performs period aggregation with selected reducer. +4. `spatial_aggregation` + - Aggregates gridded data over feature geometries. + - Produces normalized record rows (`org_unit`, `time`, `value`). +5. `build_datavalueset` + - Builds valid DHIS2 DataValueSet JSON from records. + - Serializes output to file and returns both payload and output path. + +`load_data` and `write_datavalueset` are intentionally not separate top-level components anymore; loading and writing are internalized within aggregation/build steps. + +Execution order is currently defined in: + +- `data/workflows/dhis2_datavalue_set.yaml` + +The default YAML remains the same 5-step sequence, but the engine now reads it declaratively. + +--- + +## Runtime Observability and Housekeeping + +File: `src/eo_api/workflows/services/runtime.py` + +For each component run, runtime captures: + +1. `component` +2. `status` +3. `started_at` +4. `ended_at` +5. `duration_ms` +6. `inputs` (summarized) +7. `outputs` (summarized) +8. `error` (on failure) + +Each workflow execution gets a unique `run_id`. + +### Response-Level Control of Run Details + +`include_component_run_details` controls response verbosity: + +1. 
If `false`: + - `component_runs: []` + - `component_run_details_included: false` + - `component_run_details_available: true` +2. If `true`: + - `component_runs` contains full per-component run records + - `component_run_details_included: true` + - `component_run_details_available: true` + +This keeps default responses clean while preserving debuggability when explicitly requested. + +--- + +## Run Logs + +File: `src/eo_api/workflows/services/run_logs.py` + +Workflow run logs are persisted under: + +- `/workflow_runs/` + +Persisted fields include: + +1. `run_id` +2. `status` (`completed` or `failed`) +3. normalized request payload +4. `component_runs` +5. output file path (when completed) +6. error details (when failed) + +--- + +## Error Handling Strategy + +1. `422` for request validation failures (Pydantic model constraints). +2. `404` when `dataset_id` does not exist in registry. +3. `503` for upstream connectivity issues detected during download/preflight: + - `error: "upstream_unreachable"` +4. `500` for other execution failures: + - `error: "workflow_execution_failed"` + +Failure responses include `failed_component` and `run_id` for traceability. + +--- + +## Achieved Behavior from Manual Verification + +Manual runs validated the following: + +1. WorldPop workflows now accept `country_code` from payload and execute without mandatory `.env` coupling. +2. Yearly dataset date normalization issues were resolved by period-aware mapping logic. +3. CHIRPS multi-month workflows execute correctly, with behavior improving as cache warms. +4. Workflow responses and run logs align with the 5-component chain. +5. Default response trimming works and detail flags remove ambiguity. + +Reference sample outputs: + +1. `docs/response/worldpop.json` +2. `docs/response/chirps3.json` +3. `docs/response/without_component_runs.json` +4. 
`docs/response/with_component_runs.json` + +--- + +## Testing and Quality Gates + +Primary tests: + +- `tests/test_workflows.py` + +Coverage includes: + +1. Single workflow endpoint behavior. +2. Payload validation and normalization paths. +3. Exact 5-component orchestration order. +4. Component detail include/exclude behavior. +5. Upstream connectivity error mapping. +6. Component catalog endpoint expectations. +7. Declarative workflow definition loading and default step validation. +8. Engine execution follows the definition-provided step order. + +Quality gates: + +1. `make lint` (ruff, mypy, pyright) +2. `uv run pytest -q` + +--- + +## Why This Approach + +This design intentionally balances: + +1. Simplicity for clients: + - one end-to-end endpoint and one public payload. +2. Generic dataset support: + - dataset-specific behavior comes from registry metadata and downloader wiring, not endpoint proliferation. +3. Reusability: + - component services are discoverable and executable independently. +4. Future orchestration readiness: + - component boundaries and run metadata are explicit, making Prefect/Airflow integration straightforward. + +--- + +## Sequence Diagram + +```mermaid +sequenceDiagram + autonumber + participant C as Client + participant R as /workflows route + participant M as simple_mapper + participant E as engine + participant RT as WorkflowRuntime + participant CS as components.services + participant RL as run_logs + + C->>R: POST /workflows/dhis2-datavalue-set (flat payload) + R->>M: normalize_simple_request(payload) + M-->>R: WorkflowExecuteRequest + R->>E: execute_workflow(request, include_component_run_details) + + E->>RT: run(feature_source) + RT->>CS: feature_source_component(...) + CS-->>RT: features, bbox + + E->>RT: run(download_dataset) + RT->>CS: download_dataset_component(...) + CS-->>RT: status + + E->>RT: run(temporal_aggregation) + RT->>CS: temporal_aggregation_component(...) 
+ CS-->>RT: aggregated dataset + + E->>RT: run(spatial_aggregation) + RT->>CS: spatial_aggregation_component(...) + CS-->>RT: records + + E->>RT: run(build_datavalueset) + RT->>CS: build_datavalueset_component(...) + CS-->>RT: data_value_set, output_file + + E->>RL: persist_run_log(completed|failed) + RL-->>E: run_log_file + E-->>R: WorkflowExecuteResponse + R-->>C: 200 response (trimmed or detailed component runs) +``` + +Failure path: + +1. Any component exception is captured by runtime on the failing step. +2. Engine persists failed run log with `run_id` and `failed_component`. +3. Engine returns structured error: + - `503` with `error=upstream_unreachable` for connectivity failures. + - `500` with `error=workflow_execution_failed` for all other failures. + +--- + +## Manual E2E Testing + +Use the following commands to validate discovery, execution, and error behavior end-to-end. + +1. Start API: + +```bash +uvicorn eo_api.main:app --reload +``` + +2. Verify discovered workflows: + +```bash +curl -s http://127.0.0.1:8000/workflows | jq +``` + +3. Run default 5-step workflow: + +```bash +curl -s -X POST "http://127.0.0.1:8000/workflows/dhis2-datavalue-set" \ + -H "Content-Type: application/json" \ + -d '{ + "workflow_id": "dhis2_datavalue_set_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-02-29", + "org_unit_level": 2, + "data_element": "DE_UID", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "include_component_run_details": true + }' | jq +``` + +Expected component order: + +1. `feature_source` +2. `download_dataset` +3. `temporal_aggregation` +4. `spatial_aggregation` +5. `build_datavalueset` + +4. 
Run 4-step workflow (without temporal aggregation): + +```bash +curl -s -X POST "http://127.0.0.1:8000/workflows/dhis2-datavalue-set" \ + -H "Content-Type: application/json" \ + -d '{ + "workflow_id": "dhis2_datavalue_set_without_temporal_aggregation_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-02-29", + "org_unit_level": 2, + "data_element": "DE_UID", + "spatial_reducer": "mean", + "include_component_run_details": true + }' | jq +``` + +Expected component order: + +1. `feature_source` +2. `download_dataset` +3. `spatial_aggregation` +4. `build_datavalueset` + +5. Negative test for unknown workflow: + +```bash +curl -s -X POST "http://127.0.0.1:8000/workflows/dhis2-datavalue-set" \ + -H "Content-Type: application/json" \ + -d '{ + "workflow_id": "does_not_exist", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 2, + "data_element": "DE_UID" + }' | jq +``` + +Expected result: `422` with allowed/discovered `workflow_id` values in error detail. + +--- + +## Next Technical Step + +Add a workflow governance model for multi-user environments: workflow metadata (owner/status), promotion states (draft/staging/prod), and optional signature/checksum validation before a discovered YAML can execute. 
From f30e84b2dc98d428e02862704820387392d3ce04 Mon Sep 17 00:00:00 2001 From: Abyot Asalefew Gizaw Date: Wed, 11 Mar 2026 23:40:10 +0100 Subject: [PATCH 3/5] Restore pygeoapi config and generated OpenAPI artifact for CI --- pygeoapi-config.yml | 40 +++++ pygeoapi-openapi.yml | 362 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 402 insertions(+) create mode 100644 pygeoapi-config.yml create mode 100644 pygeoapi-openapi.yml diff --git a/pygeoapi-config.yml b/pygeoapi-config.yml new file mode 100644 index 0000000..ca32423 --- /dev/null +++ b/pygeoapi-config.yml @@ -0,0 +1,40 @@ +server: + bind: + host: 0.0.0.0 + port: 5000 + url: http://127.0.0.1:8000/ogcapi + mimetype: application/json; charset=UTF-8 + encoding: utf-8 + languages: + - en-US + limits: + default_items: 20 + max_items: 50 + map: + url: https://tile.openstreetmap.org/{z}/{x}/{y}.png + attribution: OpenStreetMap + +metadata: + identification: + title: + en: DHIS2 EO API + description: + en: OGC API facade for EO services + keywords: + en: + - EO + - DHIS2 + terms_of_service: https://dhis2.org + url: https://dhis2.org + license: + name: CC-BY 4.0 + url: https://creativecommons.org/licenses/by/4.0/ + provider: + name: DHIS2 EO API + url: https://dhis2.org + contact: + name: DHIS2 + position: Team + email: climate@dhis2.org + +resources: {} diff --git a/pygeoapi-openapi.yml b/pygeoapi-openapi.yml new file mode 100644 index 0000000..d70eca1 --- /dev/null +++ b/pygeoapi-openapi.yml @@ -0,0 +1,362 @@ +components: + parameters: + bbox: + description: Only features that have a geometry that intersects the bounding + box are selected.The bounding box is provided as four or six numbers, depending + on whether the coordinate reference system includes a vertical axis (height + or depth). 
+ explode: false + in: query + name: bbox + required: false + schema: + items: + type: number + maxItems: 6 + minItems: 4 + type: array + style: form + bbox-crs: + description: Indicates the coordinate reference system for the given bbox coordinates. + explode: false + in: query + name: bbox-crs + required: false + schema: + format: uri + type: string + style: form + bbox-crs-epsg: + description: Indicates the EPSG for the given bbox coordinates. + explode: false + in: query + name: bbox-crs + required: false + schema: + default: 4326 + type: integer + style: form + crs: + description: Indicates the coordinate reference system for the results. + explode: false + in: query + name: crs + required: false + schema: + format: uri + type: string + style: form + f: + description: The optional f parameter indicates the output format which the + server shall provide as part of the response document. The default format + is GeoJSON. + explode: false + in: query + name: f + required: false + schema: + default: json + enum: + - json + - html + - jsonld + type: string + style: form + lang: + description: The optional lang parameter instructs the server return a response + in a certain language, if supported. If the language is not among the available + values, the Accept-Language header language will be used if it is supported. + If the header is missing, the default server language is used. Note that providers + may only support a single language (or often no language at all), that can + be different from the server language. Language strings can be written in + a complex (e.g. "fr-CA,fr;q=0.9,en-US;q=0.8,en;q=0.7"), simple (e.g. "de") + or locale-like (e.g. "de-CH" or "fr_BE") fashion. + in: query + name: lang + required: false + schema: + default: en-US + enum: + - en-US + type: string + offset: + description: The optional offset parameter indicates the index within the result + set from which the server shall begin presenting results in the response document. 
The + first element has an index of 0 (default). + explode: false + in: query + name: offset + required: false + schema: + default: 0 + minimum: 0 + type: integer + style: form + resourceId: + description: Configuration resource identifier + in: path + name: resourceId + required: true + schema: + type: string + skipGeometry: + description: This option can be used to skip response geometries for each feature. + explode: false + in: query + name: skipGeometry + required: false + schema: + default: false + type: boolean + style: form + vendorSpecificParameters: + description: Additional "free-form" parameters that are not explicitly defined + in: query + name: vendorSpecificParameters + schema: + additionalProperties: true + type: object + style: form + responses: + '200': + description: successful operation + '204': + description: no content + Queryables: + content: + application/json: + schema: + $ref: '#/components/schemas/queryables' + description: successful queryables operation + default: + content: + application/json: + schema: + $ref: https://schemas.opengis.net/ogcapi/processes/part1/1.0/openapi/schemas/exception.yaml + description: Unexpected error + schemas: + queryable: + properties: + description: + description: a human-readable narrative describing the queryable + type: string + language: + default: en + description: the language used for the title and description + type: string + queryable: + description: the token that may be used in a CQL predicate + type: string + title: + description: a human readable title for the queryable + type: string + type: + description: the data type of the queryable + type: string + type-ref: + description: a reference to the formal definition of the type + format: url + type: string + required: + - queryable + - type + type: object + queryables: + properties: + queryables: + items: + $ref: '#/components/schemas/queryable' + type: array + required: + - queryables + type: object +info: + contact: + name: DHIS2 EO API + 
url: https://dhis2.org + x-ogc-serviceContact: + addresses: [] + emails: + - value: climate@dhis2.org + name: DHIS2 + position: Team + description: OGC API facade for EO services + license: + name: CC-BY 4.0 + url: https://creativecommons.org/licenses/by/4.0/ + termsOfService: https://dhis2.org + title: DHIS2 EO API + version: 0.22.0 + x-keywords: + - EO + - DHIS2 +openapi: 3.0.2 +paths: + /: + get: + description: Landing page + operationId: getLandingPage + parameters: + - $ref: '#/components/parameters/f' + - $ref: '#/components/parameters/lang' + responses: + '200': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/LandingPage + '400': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/InvalidParameter + '500': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/ServerError + summary: Landing page + tags: + - server + /collections: + get: + description: Collections + operationId: getCollections + parameters: + - $ref: '#/components/parameters/f' + - $ref: '#/components/parameters/lang' + responses: + '200': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/LandingPage + '400': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/InvalidParameter + '500': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/ServerError + summary: Collections + tags: + - server + /conformance: + get: + description: API conformance definition + operationId: getConformanceDeclaration + parameters: + - $ref: '#/components/parameters/f' + - $ref: '#/components/parameters/lang' + responses: + '200': + $ref: 
https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/LandingPage + '400': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/InvalidParameter + '500': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/ServerError + summary: API conformance definition + tags: + - server + /jobs: + get: + description: Retrieve a list of jobs + operationId: getJobs + responses: + '200': + $ref: '#/components/responses/200' + '404': + $ref: https://schemas.opengis.net/ogcapi/processes/part1/1.0/openapi/responses/NotFound.yaml + default: + $ref: '#/components/responses/default' + summary: Retrieve jobs list + tags: + - jobs + /jobs/{jobId}: + delete: + description: Cancel / delete job + operationId: deleteJob + parameters: + - &id001 + description: job identifier + in: path + name: jobId + required: true + schema: + type: string + responses: + '204': + $ref: '#/components/responses/204' + '404': + $ref: https://schemas.opengis.net/ogcapi/processes/part1/1.0/openapi/responses/NotFound.yaml + default: + $ref: '#/components/responses/default' + summary: Cancel / delete job + tags: + - jobs + get: + description: Retrieve job details + operationId: getJob + parameters: + - *id001 + - $ref: '#/components/parameters/f' + responses: + '200': + $ref: '#/components/responses/200' + '404': + $ref: https://schemas.opengis.net/ogcapi/processes/part1/1.0/openapi/responses/NotFound.yaml + default: + $ref: '#/components/responses/default' + summary: Retrieve job details + tags: + - jobs + /jobs/{jobId}/results: + get: + description: Retrieve job results + operationId: getJobResults + parameters: + - *id001 + - $ref: '#/components/parameters/f' + responses: + '200': + $ref: '#/components/responses/200' + '404': + $ref: https://schemas.opengis.net/ogcapi/processes/part1/1.0/openapi/responses/NotFound.yaml + default: + $ref: 
'#/components/responses/default' + summary: Retrieve job results + tags: + - jobs + /openapi: + get: + description: This document + operationId: getOpenapi + parameters: + - $ref: '#/components/parameters/f' + - $ref: '#/components/parameters/lang' + - description: UI to render the OpenAPI document + explode: false + in: query + name: ui + required: false + schema: + default: swagger + enum: + - swagger + - redoc + type: string + style: form + responses: + '200': + $ref: '#/components/responses/200' + '400': + $ref: https://schemas.opengis.net/ogcapi/features/part1/1.0/openapi/ogcapi-features-1.yaml#/components/responses/InvalidParameter + default: + $ref: '#/components/responses/default' + summary: This document + tags: + - server +servers: +- description: OGC API facade for EO services + url: http://127.0.0.1:8000/ogcapi +tags: +- description: OGC API facade for EO services + externalDocs: + description: information + url: https://dhis2.org + name: server +- name: coverages +- name: edr +- name: records +- name: features +- name: maps +- name: processes +- name: jobs +- name: tiles +- name: stac + From f616ce53e9983b91f7edacc044e309e5cb8d11c3 Mon Sep 17 00:00:00 2001 From: Abyot Asalefew Gizaw Date: Thu, 12 Mar 2026 08:57:33 +0100 Subject: [PATCH 4/5] Harden workflow engine with step configs, contracts, and error codes --- data/workflows/dhis2_datavalue_set.yaml | 5 + ...alue_set_without_temporal_aggregation.yaml | 4 + docs/workflow-orchestration.md | 60 ++- src/eo_api/components/schemas.py | 5 + src/eo_api/components/services.py | 189 +++++++-- src/eo_api/workflows/routes.py | 1 + src/eo_api/workflows/schemas.py | 1 - src/eo_api/workflows/services/definitions.py | 18 +- src/eo_api/workflows/services/engine.py | 359 ++++++++++++++---- src/eo_api/workflows/services/run_logs.py | 6 + .../workflows/services/simple_mapper.py | 9 +- tests/test_workflows.py | 201 ++++++++-- 12 files changed, 691 insertions(+), 167 deletions(-) diff --git 
a/data/workflows/dhis2_datavalue_set.yaml b/data/workflows/dhis2_datavalue_set.yaml index 180da1d..8f5afd8 100644 --- a/data/workflows/dhis2_datavalue_set.yaml +++ b/data/workflows/dhis2_datavalue_set.yaml @@ -2,7 +2,12 @@ workflow_id: dhis2_datavalue_set_v1 version: 1 steps: - component: feature_source + version: v1 - component: download_dataset + version: v1 - component: temporal_aggregation + version: v1 - component: spatial_aggregation + version: v1 - component: build_datavalueset + version: v1 diff --git a/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml b/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml index 6d1b1f0..c7baaf5 100644 --- a/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml +++ b/data/workflows/dhis2_datavalue_set_without_temporal_aggregation.yaml @@ -2,6 +2,10 @@ workflow_id: dhis2_datavalue_set_without_temporal_aggregation_v1 version: 1 steps: - component: feature_source + version: v1 - component: download_dataset + version: v1 - component: spatial_aggregation + version: v1 - component: build_datavalueset + version: v1 diff --git a/docs/workflow-orchestration.md b/docs/workflow-orchestration.md index 22315c8..97d26a1 100644 --- a/docs/workflow-orchestration.md +++ b/docs/workflow-orchestration.md @@ -34,6 +34,9 @@ The current implementation provides: 8. Optional inclusion of detailed component run traces in API responses. 9. Discoverable standalone component endpoints under `/components` for direct execution and future orchestrator integration. 10. Declarative workflow assembly via YAML (`data/workflows/dhis2_datavalue_set.yaml`) executed by the workflow engine. +11. Registry-driven component dispatch in engine (no component-specific `if/elif` chain). +12. Step-level YAML config support with strict validation and `$request.` interpolation. +13. Stable workflow error contract with `error_code` and `failed_component_version`. --- @@ -96,7 +99,6 @@ Notes: 1. 
`feature_id_property` defaults to `"id"` and controls which feature property maps to DHIS2 org unit ID in spatial aggregation/DataValueSet construction. 2. `country_code` is accepted in request and passed to dataset downloaders (instead of forcing `.env` only). -3. `reducer` is accepted as an alias and mapped to both temporal and spatial reducer when provided. --- @@ -111,9 +113,9 @@ Public flat payload is normalized to internal `WorkflowExecuteRequest` with comp - `org_unit_ids` -> `source_type=dhis2_ids` 2. `temporal_aggregation` config: - `target_period_type` from `temporal_resolution` - - `method` from `temporal_reducer` (or `reducer` alias) + - `method` from `temporal_reducer` 3. `spatial_aggregation` config: - - `method` from `spatial_reducer` (or `reducer` alias) + - `method` from `spatial_reducer` 4. `dhis2` config: - `data_element_uid` from `data_element` @@ -169,9 +171,10 @@ Responsibilities: 1. Discover, load, and validate declarative workflow definitions from `data/workflows/*.yaml`. 2. Enforce supported component names. -3. Enforce terminal `build_datavalueset` step for this end-to-end workflow. -4. Enforce output-to-input compatibility across the full accumulated context (not just adjacent steps). -5. Drive runtime execution order from YAML instead of hardcoded sequence. +3. Enforce supported component versions (currently `v1`) and validate per-step `config`. +4. Enforce terminal `build_datavalueset` step for this end-to-end workflow. +5. Enforce output-to-input compatibility across the full accumulated context (not just adjacent steps). +6. Drive runtime execution order from YAML through a registry-dispatch model. ### Reusable Component Service Layer @@ -238,13 +241,17 @@ Details: - Builds valid DHIS2 DataValueSet JSON from records. - Serializes output to file and returns both payload and output path. 
-`load_data` and `write_datavalueset` are intentionally not separate top-level components anymore; loading and writing are internalized within aggregation/build steps. - -Execution order is currently defined in: +Execution order and step metadata are currently defined in: - `data/workflows/dhis2_datavalue_set.yaml` -The default YAML remains the same 5-step sequence, but the engine now reads it declaratively. +Workflow step schema now supports: + +1. `component` +2. `version` (default `v1`) +3. `config` (default `{}`) + +The default YAML remains the same 5-step sequence, but the engine reads it declaratively and dispatches components through a registry map. --- @@ -298,38 +305,29 @@ Persisted fields include: 4. `component_runs` 5. output file path (when completed) 6. error details (when failed) +7. `error_code` (when failed) +8. `failed_component` (when failed) +9. `failed_component_version` (when failed) --- ## Error Handling Strategy -1. `422` for request validation failures (Pydantic model constraints). +1. `422` for request/definition/config validation failures. 2. `404` when `dataset_id` does not exist in registry. -3. `503` for upstream connectivity issues detected during download/preflight: +3. `503` for upstream connectivity failures: - `error: "upstream_unreachable"` + - `error_code: "UPSTREAM_UNREACHABLE"` 4. `500` for other execution failures: - `error: "workflow_execution_failed"` + - `error_code: "EXECUTION_FAILED"` (or other stable mapped codes) -Failure responses include `failed_component` and `run_id` for traceability. - ---- - -## Achieved Behavior from Manual Verification - -Manual runs validated the following: - -1. WorldPop workflows now accept `country_code` from payload and execute without mandatory `.env` coupling. -2. Yearly dataset date normalization issues were resolved by period-aware mapping logic. -3. CHIRPS multi-month workflows execute correctly, with behavior improving as cache warms. -4. 
Workflow responses and run logs align with the 5-component chain. -5. Default response trimming works and detail flags remove ambiguity. - -Reference sample outputs: +Failure responses include: -1. `docs/response/worldpop.json` -2. `docs/response/chirps3.json` -3. `docs/response/without_component_runs.json` -4. `docs/response/with_component_runs.json` +1. `error_code` +2. `failed_component` +3. `failed_component_version` +4. `run_id` --- diff --git a/src/eo_api/components/schemas.py b/src/eo_api/components/schemas.py index 80c772e..a0920dc 100644 --- a/src/eo_api/components/schemas.py +++ b/src/eo_api/components/schemas.py @@ -18,9 +18,14 @@ class ComponentDefinition(BaseModel): """Component metadata for discovery.""" name: str + version: str = "v1" description: str inputs: list[str] outputs: list[str] + input_schema: dict[str, Any] = Field(default_factory=dict) + config_schema: dict[str, Any] = Field(default_factory=dict) + output_schema: dict[str, Any] = Field(default_factory=dict) + error_codes: list[str] = Field(default_factory=list) class ComponentCatalogResponse(BaseModel): diff --git a/src/eo_api/components/services.py b/src/eo_api/components/services.py index 2e99133..8bd0e1d 100644 --- a/src/eo_api/components/services.py +++ b/src/eo_api/components/services.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Final import xarray as xr from fastapi import HTTPException @@ -23,41 +23,164 @@ from ..workflows.services.temporal import aggregate_temporal from .schemas import ComponentDefinition +_ERROR_CODES_V1: Final[list[str]] = [ + "INPUT_VALIDATION_FAILED", + "CONFIG_VALIDATION_FAILED", + "OUTPUT_VALIDATION_FAILED", + "UPSTREAM_UNREACHABLE", + "EXECUTION_FAILED", +] + +_COMPONENT_REGISTRY: Final[dict[str, ComponentDefinition]] = { + "feature_source@v1": ComponentDefinition( + name="feature_source", + version="v1", + description="Resolve feature source and compute bbox.", + inputs=["feature_source"], + 
outputs=["features", "bbox"], + input_schema={ + "type": "object", + "properties": {"feature_source": {"type": "object"}}, + "required": ["feature_source"], + }, + config_schema={"type": "object", "properties": {}, "additionalProperties": False}, + output_schema={ + "type": "object", + "properties": { + "features": {"type": "object"}, + "bbox": {"type": "array", "items": {"type": "number"}, "minItems": 4, "maxItems": 4}, + }, + "required": ["features", "bbox"], + }, + error_codes=_ERROR_CODES_V1, + ), + "download_dataset@v1": ComponentDefinition( + name="download_dataset", + version="v1", + description="Download dataset files for period and bbox.", + inputs=["dataset_id", "start", "end", "overwrite", "country_code", "bbox"], + outputs=["status"], + input_schema={ + "type": "object", + "properties": { + "dataset_id": {"type": "string"}, + "start": {"type": "string"}, + "end": {"type": "string"}, + "overwrite": {"type": "boolean"}, + "country_code": {"type": ["string", "null"]}, + "bbox": {"type": "array", "items": {"type": "number"}, "minItems": 4, "maxItems": 4}, + }, + "required": ["dataset_id", "start", "end", "overwrite", "bbox"], + }, + config_schema={ + "type": "object", + "properties": { + "overwrite": {"type": "boolean"}, + "country_code": {"type": ["string", "null"]}, + }, + "additionalProperties": False, + }, + output_schema={"type": "object", "properties": {"status": {"type": "string"}}}, + error_codes=_ERROR_CODES_V1, + ), + "temporal_aggregation@v1": ComponentDefinition( + name="temporal_aggregation", + version="v1", + description="Aggregate dataset over time dimension.", + inputs=["dataset_id", "start", "end", "target_period_type", "method", "bbox"], + outputs=["dataset"], + input_schema={ + "type": "object", + "properties": { + "dataset_id": {"type": "string"}, + "start": {"type": "string"}, + "end": {"type": "string"}, + "target_period_type": {"type": "string"}, + "method": {"type": "string"}, + "bbox": {"type": ["array", "null"], "items": {"type": 
"number"}}, + }, + "required": ["dataset_id", "start", "end", "target_period_type", "method"], + }, + config_schema={ + "type": "object", + "properties": { + "target_period_type": {"type": "string"}, + "method": {"type": "string"}, + }, + "additionalProperties": False, + }, + output_schema={"type": "object", "properties": {"dataset": {"type": "object"}}}, + error_codes=_ERROR_CODES_V1, + ), + "spatial_aggregation@v1": ComponentDefinition( + name="spatial_aggregation", + version="v1", + description="Aggregate gridded dataset to features.", + inputs=["dataset_id", "start", "end", "feature_source", "method"], + outputs=["records"], + input_schema={ + "type": "object", + "properties": { + "dataset_id": {"type": "string"}, + "start": {"type": "string"}, + "end": {"type": "string"}, + "feature_source": {"type": "object"}, + "method": {"type": "string"}, + }, + "required": ["dataset_id", "start", "end", "feature_source", "method"], + }, + config_schema={ + "type": "object", + "properties": { + "method": {"type": "string"}, + "feature_id_property": {"type": "string"}, + }, + "additionalProperties": False, + }, + output_schema={"type": "object", "properties": {"records": {"type": "array"}}}, + error_codes=_ERROR_CODES_V1, + ), + "build_datavalueset@v1": ComponentDefinition( + name="build_datavalueset", + version="v1", + description="Build and serialize DHIS2 DataValueSet JSON.", + inputs=["dataset_id", "period_type", "records", "dhis2"], + outputs=["data_value_set", "output_file"], + input_schema={ + "type": "object", + "properties": { + "dataset_id": {"type": "string"}, + "period_type": {"type": "string"}, + "records": {"type": "array"}, + "dhis2": {"type": "object"}, + }, + "required": ["dataset_id", "period_type", "records", "dhis2"], + }, + config_schema={ + "type": "object", + "properties": { + "period_type": {"type": "string"}, + }, + "additionalProperties": False, + }, + output_schema={ + "type": "object", + "properties": {"data_value_set": {"type": "object"}, 
"output_file": {"type": "string"}}, + "required": ["data_value_set", "output_file"], + }, + error_codes=_ERROR_CODES_V1, + ), +} + def component_catalog() -> list[ComponentDefinition]: """Return all discoverable component definitions.""" - return [ - ComponentDefinition( - name="feature_source", - description="Resolve feature source and compute bbox.", - inputs=["feature_source"], - outputs=["features", "bbox"], - ), - ComponentDefinition( - name="download_dataset", - description="Download dataset files for period and bbox.", - inputs=["dataset_id", "start", "end", "overwrite", "country_code", "bbox"], - outputs=["status"], - ), - ComponentDefinition( - name="temporal_aggregation", - description="Aggregate dataset over time dimension.", - inputs=["dataset_id", "start", "end", "target_period_type", "method", "bbox"], - outputs=["dataset"], - ), - ComponentDefinition( - name="spatial_aggregation", - description="Aggregate gridded dataset to features.", - inputs=["dataset_id", "start", "end", "feature_source", "method"], - outputs=["records"], - ), - ComponentDefinition( - name="build_datavalueset", - description="Build and serialize DHIS2 DataValueSet JSON.", - inputs=["dataset_id", "period_type", "records", "dhis2"], - outputs=["data_value_set", "output_file"], - ), - ] + return list(_COMPONENT_REGISTRY.values()) + + +def component_registry() -> dict[str, ComponentDefinition]: + """Return registry entries keyed by component@version.""" + return dict(_COMPONENT_REGISTRY) def feature_source_component(config: FeatureSourceConfig) -> tuple[dict[str, Any], list[float]]: diff --git a/src/eo_api/workflows/routes.py b/src/eo_api/workflows/routes.py index da9d271..6716c51 100644 --- a/src/eo_api/workflows/routes.py +++ b/src/eo_api/workflows/routes.py @@ -37,5 +37,6 @@ def run_dhis2_datavalue_set_workflow(payload: WorkflowRequest) -> WorkflowExecut return execute_workflow( request, workflow_id=payload.workflow_id, + request_params=payload.model_dump(), 
include_component_run_details=payload.include_component_run_details, ) diff --git a/src/eo_api/workflows/schemas.py b/src/eo_api/workflows/schemas.py index 7e48c5f..cfd0645 100644 --- a/src/eo_api/workflows/schemas.py +++ b/src/eo_api/workflows/schemas.py @@ -161,7 +161,6 @@ class WorkflowRequest(BaseModel): flavor: str | None = None country_code: str | None = None output_format: str | None = None - reducer: str | None = None include_component_run_details: bool = False @model_validator(mode="after") diff --git a/src/eo_api/workflows/services/definitions.py b/src/eo_api/workflows/services/definitions.py index 9876406..a05965c 100644 --- a/src/eo_api/workflows/services/definitions.py +++ b/src/eo_api/workflows/services/definitions.py @@ -3,10 +3,10 @@ from __future__ import annotations from pathlib import Path -from typing import Final, Literal +from typing import Any, Final, Literal import yaml -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator ComponentName = Literal[ "feature_source", @@ -17,6 +17,7 @@ ] SUPPORTED_COMPONENTS: Final[set[str]] = set(ComponentName.__args__) # type: ignore[attr-defined] +SUPPORTED_COMPONENT_VERSIONS: Final[dict[str, set[str]]] = {component: {"v1"} for component in SUPPORTED_COMPONENTS} COMPONENT_INPUTS: Final[dict[str, set[str]]] = { "feature_source": set(), @@ -43,6 +44,19 @@ class WorkflowStep(BaseModel): """One component step in a declarative workflow definition.""" component: ComponentName + version: str = "v1" + config: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_component_version(self) -> "WorkflowStep": + """Ensure component@version exists in the registered component catalog.""" + supported_versions = SUPPORTED_COMPONENT_VERSIONS.get(self.component, set()) + if self.version not in supported_versions: + known = ", ".join(sorted(supported_versions)) or "" + raise ValueError( + f"Unsupported component version 
'{self.component}@{self.version}'. Supported versions: {known}" + ) + return self class WorkflowDefinition(BaseModel): diff --git a/src/eo_api/workflows/services/engine.py b/src/eo_api/workflows/services/engine.py index 9e7da9e..d8e1ad7 100644 --- a/src/eo_api/workflows/services/engine.py +++ b/src/eo_api/workflows/services/engine.py @@ -2,22 +2,44 @@ from __future__ import annotations +from collections.abc import Callable from typing import Any from fastapi import HTTPException +from pydantic import BaseModel, ConfigDict, ValidationError from ...components import services as component_services from ...data_registry.services.datasets import get_dataset -from ..schemas import WorkflowExecuteRequest, WorkflowExecuteResponse +from ..schemas import AggregationMethod, PeriodType, WorkflowExecuteRequest, WorkflowExecuteResponse from .definitions import WorkflowDefinition, load_workflow_definition from .run_logs import persist_run_log from .runtime import WorkflowRuntime +class WorkflowComponentError(RuntimeError): + """Typed component failure with stable error code and component context.""" + + def __init__( + self, + *, + error_code: str, + message: str, + component: str, + component_version: str, + status_code: int, + ) -> None: + super().__init__(message) + self.error_code = error_code + self.component = component + self.component_version = component_version + self.status_code = status_code + + def execute_workflow( request: WorkflowExecuteRequest, *, workflow_id: str = "dhis2_datavalue_set_v1", + request_params: dict[str, Any] | None = None, include_component_run_details: bool = False, ) -> WorkflowExecuteResponse: """Execute the feature->download->aggregate->DataValueSet workflow.""" @@ -38,6 +60,7 @@ def execute_workflow( workflow=workflow, runtime=runtime, request=request, + request_params=request_params, dataset=dataset, context=context, ) @@ -69,6 +92,29 @@ def execute_workflow( component_run_details_included=include_component_run_details, 
component_run_details_available=True, ) + except WorkflowComponentError as exc: + persist_run_log( + run_id=runtime.run_id, + request=request, + component_runs=runtime.component_runs, + status="failed", + error=str(exc), + error_code=exc.error_code, + failed_component=exc.component, + failed_component_version=exc.component_version, + ) + error = "upstream_unreachable" if exc.error_code == "UPSTREAM_UNREACHABLE" else "workflow_execution_failed" + raise HTTPException( + status_code=exc.status_code, + detail={ + "error": error, + "error_code": exc.error_code, + "message": str(exc), + "failed_component": exc.component, + "failed_component_version": exc.component_version, + "run_id": runtime.run_id, + }, + ) from exc except HTTPException: persist_run_log( run_id=runtime.run_id, @@ -85,24 +131,17 @@ def execute_workflow( component_runs=runtime.component_runs, status="failed", error=str(exc), + error_code="EXECUTION_FAILED", ) last_component = runtime.component_runs[-1].component if runtime.component_runs else "unknown" - if _is_upstream_connectivity_error(exc): - raise HTTPException( - status_code=503, - detail={ - "error": "upstream_unreachable", - "message": "Could not reach upstream data source. 
Check network/proxy and retry.", - "failed_component": last_component, - "run_id": runtime.run_id, - }, - ) from exc raise HTTPException( status_code=500, detail={ "error": "workflow_execution_failed", + "error_code": "EXECUTION_FAILED", "message": str(exc), "failed_component": last_component, + "failed_component_version": "unknown", "run_id": runtime.run_id, }, ) from exc @@ -127,77 +166,180 @@ def _execute_workflow_steps( workflow: WorkflowDefinition, runtime: WorkflowRuntime, request: WorkflowExecuteRequest, + request_params: dict[str, Any] | None, dataset: dict[str, Any], context: dict[str, Any], ) -> None: """Execute workflow components using declarative YAML step order.""" + executors: dict[str, StepExecutor] = { + "feature_source": _run_feature_source, + "download_dataset": _run_download_dataset, + "temporal_aggregation": _run_temporal_aggregation, + "spatial_aggregation": _run_spatial_aggregation, + "build_datavalueset": _run_build_datavalueset, + } + for step in workflow.steps: - if step.component == "feature_source": - features, bbox = runtime.run( - "feature_source", - component_services.feature_source_component, - config=request.feature_source, + executor = executors.get(step.component) + if executor is None: + raise WorkflowComponentError( + error_code="INPUT_VALIDATION_FAILED", + message=f"Unsupported workflow component '{step.component}'", + component=step.component, + component_version=step.version, + status_code=422, ) - context["features"] = features - context["bbox"] = bbox - continue - - if step.component == "download_dataset": - runtime.run( - "download_dataset", - component_services.download_dataset_component, - dataset=dataset, - start=request.start, - end=request.end, - overwrite=request.overwrite, - country_code=request.country_code, - bbox=_require_context(context, "bbox"), - ) - continue + try: + step_config = _resolve_step_config(step.config, request_params or {}) + _validate_step_config(step.component, step.version, step_config) + 
except ValueError as exc: + raise WorkflowComponentError( + error_code="CONFIG_VALIDATION_FAILED", + message=str(exc), + component=step.component, + component_version=step.version, + status_code=422, + ) from exc - if step.component == "temporal_aggregation": - temporal_ds = runtime.run( - "temporal_aggregation", - component_services.temporal_aggregation_component, + try: + updates = executor( + runtime=runtime, + request=request, dataset=dataset, - start=request.start, - end=request.end, - bbox=_require_context(context, "bbox"), - target_period_type=request.temporal_aggregation.target_period_type, - method=request.temporal_aggregation.method, + context=context, + step_config=step_config, ) - context["temporal_dataset"] = temporal_ds - continue + except Exception as exc: + if _is_upstream_connectivity_error(exc): + raise WorkflowComponentError( + error_code="UPSTREAM_UNREACHABLE", + message="Could not reach upstream data source. Check network/proxy and retry.", + component=step.component, + component_version=step.version, + status_code=503, + ) from exc + raise WorkflowComponentError( + error_code="EXECUTION_FAILED", + message=str(exc), + component=step.component, + component_version=step.version, + status_code=500, + ) from exc - if step.component == "spatial_aggregation": - records = runtime.run( - "spatial_aggregation", - component_services.spatial_aggregation_component, - dataset=dataset, - start=request.start, - end=request.end, - bbox=_require_context(context, "bbox"), - features=_require_context(context, "features"), - method=request.spatial_aggregation.method, - feature_id_property=request.dhis2.org_unit_property, - ) - context["records"] = records - continue - - if step.component == "build_datavalueset": - data_value_set, output_file = runtime.run( - "build_datavalueset", - component_services.build_datavalueset_component, - records=_require_context(context, "records"), - dataset_id=request.dataset_id, - 
period_type=request.temporal_aggregation.target_period_type, - dhis2=request.dhis2, - ) - context["data_value_set"] = data_value_set - context["output_file"] = output_file - continue + context.update(updates) + + +type StepExecutor = Callable[..., dict[str, Any]] - raise RuntimeError(f"Unsupported workflow component '{step.component}'") + +def _run_feature_source( + *, + runtime: WorkflowRuntime, + request: WorkflowExecuteRequest, + dataset: dict[str, Any], + context: dict[str, Any], + step_config: dict[str, Any], +) -> dict[str, Any]: + del dataset, context, step_config + features, bbox = runtime.run( + "feature_source", + component_services.feature_source_component, + config=request.feature_source, + ) + return {"features": features, "bbox": bbox} + + +def _run_download_dataset( + *, + runtime: WorkflowRuntime, + request: WorkflowExecuteRequest, + dataset: dict[str, Any], + context: dict[str, Any], + step_config: dict[str, Any], +) -> dict[str, Any]: + overwrite = bool(step_config.get("overwrite", request.overwrite)) + country_code = step_config.get("country_code", request.country_code) + runtime.run( + "download_dataset", + component_services.download_dataset_component, + dataset=dataset, + start=request.start, + end=request.end, + overwrite=overwrite, + country_code=country_code, + bbox=_require_context(context, "bbox"), + ) + return {} + + +def _run_temporal_aggregation( + *, + runtime: WorkflowRuntime, + request: WorkflowExecuteRequest, + dataset: dict[str, Any], + context: dict[str, Any], + step_config: dict[str, Any], +) -> dict[str, Any]: + target_period_type = PeriodType( + str(step_config.get("target_period_type", request.temporal_aggregation.target_period_type)) + ) + method = AggregationMethod(str(step_config.get("method", request.temporal_aggregation.method))) + temporal_ds = runtime.run( + "temporal_aggregation", + component_services.temporal_aggregation_component, + dataset=dataset, + start=request.start, + end=request.end, + 
bbox=_require_context(context, "bbox"), + target_period_type=target_period_type, + method=method, + ) + return {"temporal_dataset": temporal_ds} + + +def _run_spatial_aggregation( + *, + runtime: WorkflowRuntime, + request: WorkflowExecuteRequest, + dataset: dict[str, Any], + context: dict[str, Any], + step_config: dict[str, Any], +) -> dict[str, Any]: + method = AggregationMethod(str(step_config.get("method", request.spatial_aggregation.method))) + feature_id_property = str(step_config.get("feature_id_property", request.dhis2.org_unit_property)) + records = runtime.run( + "spatial_aggregation", + component_services.spatial_aggregation_component, + dataset=dataset, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + features=_require_context(context, "features"), + method=method, + feature_id_property=feature_id_property, + ) + return {"records": records} + + +def _run_build_datavalueset( + *, + runtime: WorkflowRuntime, + request: WorkflowExecuteRequest, + dataset: dict[str, Any], + context: dict[str, Any], + step_config: dict[str, Any], +) -> dict[str, Any]: + del dataset + period_type = PeriodType(str(step_config.get("period_type", request.temporal_aggregation.target_period_type))) + data_value_set, output_file = runtime.run( + "build_datavalueset", + component_services.build_datavalueset_component, + records=_require_context(context, "records"), + dataset_id=request.dataset_id, + period_type=period_type, + dhis2=request.dhis2, + ) + return {"data_value_set": data_value_set, "output_file": output_file} def _require_context(context: dict[str, Any], key: str) -> Any: @@ -205,3 +347,78 @@ def _require_context(context: dict[str, Any], key: str) -> Any: if key not in context: raise RuntimeError(f"Workflow definition missing prerequisite for '{key}'") return context[key] + + +def _resolve_step_config(config: dict[str, Any], request_params: dict[str, Any]) -> dict[str, Any]: + """Resolve $request. 
tokens in step config.""" + resolved: dict[str, Any] = {} + for key, value in config.items(): + resolved[key] = _resolve_value(value, request_params) + return resolved + + +def _resolve_value(value: Any, request_params: dict[str, Any]) -> Any: + """Resolve a config value recursively.""" + if isinstance(value, str) and value.startswith("$request."): + field = value.removeprefix("$request.") + if field not in request_params: + raise ValueError(f"Unknown request field in config token: {value}") + return request_params[field] + if isinstance(value, dict): + return {k: _resolve_value(v, request_params) for k, v in value.items()} + if isinstance(value, list): + return [_resolve_value(v, request_params) for v in value] + return value + + +class _FeatureSourceStepConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class _DownloadDatasetStepConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + overwrite: bool | None = None + country_code: str | None = None + + +class _TemporalAggregationStepConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + target_period_type: PeriodType | None = None + method: AggregationMethod | None = None + + +class _SpatialAggregationStepConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + method: AggregationMethod | None = None + feature_id_property: str | None = None + + +class _BuildDataValueSetStepConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + period_type: PeriodType | None = None + + +_STEP_CONFIG_MODELS: dict[str, type[BaseModel]] = { + "feature_source": _FeatureSourceStepConfig, + "download_dataset": _DownloadDatasetStepConfig, + "temporal_aggregation": _TemporalAggregationStepConfig, + "spatial_aggregation": _SpatialAggregationStepConfig, + "build_datavalueset": _BuildDataValueSetStepConfig, +} + + +def _validate_step_config(component: str, version: str, config: dict[str, Any]) -> None: + """Validate step config with strict Pydantic models.""" + if version != "v1": + 
raise ValueError(f"Unsupported component version for config validation: {component}@{version}") + model = _STEP_CONFIG_MODELS.get(component) + if model is None: + raise ValueError(f"No config schema registered for component '{component}'") + try: + model.model_validate(config) + except ValidationError as exc: + raise ValueError(f"Invalid config for component '{component}@{version}': {exc}") from exc diff --git a/src/eo_api/workflows/services/run_logs.py b/src/eo_api/workflows/services/run_logs.py index ea4d375..aa9fef2 100644 --- a/src/eo_api/workflows/services/run_logs.py +++ b/src/eo_api/workflows/services/run_logs.py @@ -18,6 +18,9 @@ def persist_run_log( status: str, output_file: str | None = None, error: str | None = None, + error_code: str | None = None, + failed_component: str | None = None, + failed_component_version: str | None = None, ) -> str: """Write workflow run metadata to disk and return file path.""" logs_dir = DOWNLOAD_DIR / "workflow_runs" @@ -32,6 +35,9 @@ def persist_run_log( "component_runs": [run.model_dump(mode="json") for run in component_runs], "output_file": output_file, "error": error, + "error_code": error_code, + "failed_component": failed_component, + "failed_component_version": failed_component_version, } path.write_text(json.dumps(payload, indent=2), encoding="utf-8") return str(path) diff --git a/src/eo_api/workflows/services/simple_mapper.py b/src/eo_api/workflows/services/simple_mapper.py index 74b8758..e8c016f 100644 --- a/src/eo_api/workflows/services/simple_mapper.py +++ b/src/eo_api/workflows/services/simple_mapper.py @@ -6,7 +6,6 @@ from ...data_registry.services.datasets import get_dataset from ..schemas import ( - AggregationMethod, Dhis2DataValueSetConfig, FeatureSourceConfig, FeatureSourceType, @@ -63,10 +62,6 @@ def normalize_simple_request(payload: WorkflowRequest) -> tuple[WorkflowExecuteR else: raise HTTPException(status_code=422, detail="Provide org_unit_level or org_unit_ids") - reducer_alias = 
AggregationMethod(inputs.reducer.lower()) if inputs.reducer else None - spatial_method = reducer_alias or inputs.spatial_reducer - temporal_method = reducer_alias or inputs.temporal_reducer - normalized = WorkflowExecuteRequest( dataset_id=dataset_id, start=start, @@ -76,9 +71,9 @@ def normalize_simple_request(payload: WorkflowRequest) -> tuple[WorkflowExecuteR feature_source=feature_source, temporal_aggregation=TemporalAggregationConfig( target_period_type=inputs.temporal_resolution, - method=temporal_method, + method=inputs.temporal_reducer, ), - spatial_aggregation=SpatialAggregationConfig(method=spatial_method), + spatial_aggregation=SpatialAggregationConfig(method=inputs.spatial_reducer), dhis2=Dhis2DataValueSetConfig(data_element_uid=inputs.data_element), ) diff --git a/tests/test_workflows.py b/tests/test_workflows.py index da16352..4056629 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import pytest import xarray as xr @@ -82,6 +82,12 @@ def test_components_catalog_endpoint_returns_five_components(client: TestClient) "spatial_aggregation", "build_datavalueset", } + for item in items: + assert item["version"] == "v1" + assert isinstance(item["input_schema"], dict) + assert isinstance(item["config_schema"], dict) + assert isinstance(item["output_schema"], dict) + assert "EXECUTION_FAILED" in item["error_codes"] def test_workflow_endpoint_returns_response_shape(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: @@ -110,9 +116,19 @@ def test_workflow_endpoint_returns_response_shape(client: TestClient, monkeypatc }, component_runs=[], ) + + def _execute_stub( + payload: Any, + workflow_id: str = "dhis2_datavalue_set_v1", + request_params: dict[str, Any] | None = None, + include_component_run_details: bool = False, + ) -> WorkflowExecuteResponse: + del payload, workflow_id, request_params, include_component_run_details + return 
stub + monkeypatch.setattr( "eo_api.workflows.routes.execute_workflow", - lambda payload, workflow_id="dhis2_datavalue_set_v1", include_component_run_details=False: stub, + _execute_stub, ) response = client.post("/workflows/dhis2-datavalue-set", json=_valid_public_payload()) @@ -162,10 +178,20 @@ def test_workflow_endpoint_accepts_simplified_payload(client: TestClient, monkey data_value_set={"dataValues": []}, component_runs=[], ) + + def _execute_stub( + payload: Any, + workflow_id: str = "dhis2_datavalue_set_v1", + request_params: dict[str, Any] | None = None, + include_component_run_details: bool = False, + ) -> WorkflowExecuteResponse: + del payload, workflow_id, request_params, include_component_run_details + return stub + monkeypatch.setattr("eo_api.workflows.routes.normalize_simple_request", lambda payload: (normalized, [])) monkeypatch.setattr( "eo_api.workflows.routes.execute_workflow", - lambda payload, workflow_id="dhis2_datavalue_set_v1", include_component_run_details=False: stub, + _execute_stub, ) response = client.post("/workflows/dhis2-datavalue-set", json=_valid_public_payload()) @@ -328,6 +354,11 @@ def test_engine_returns_503_when_upstream_unreachable(monkeypatch: pytest.Monkey engine.execute_workflow(request) assert exc_info.value.status_code == 503 + detail = cast(dict[str, Any], exc_info.value.detail) + assert detail["error"] == "upstream_unreachable" + assert detail["error_code"] == "UPSTREAM_UNREACHABLE" + assert detail["failed_component"] == "download_dataset" + assert detail["failed_component_version"] == "v1" def test_mapper_uses_year_format_for_yearly_dataset() -> None: @@ -348,25 +379,6 @@ def test_mapper_uses_year_format_for_yearly_dataset() -> None: assert normalized.end == "2026" -def test_mapper_reducer_alias_overrides_spatial_and_temporal_reducers() -> None: - normalized, _warnings = normalize_simple_request( - WorkflowRequest.model_validate( - { - "dataset_id": "worldpop_population_yearly", - "country_code": "SLE", - 
"start_year": 2015, - "end_year": 2026, - "org_unit_level": 2, - "data_element": "DE_UID", - "temporal_resolution": "yearly", - "reducer": "sum", - } - ) - ) - assert normalized.spatial_aggregation.method.value == "sum" - assert normalized.temporal_aggregation.method.value == "sum" - - def test_mapper_uses_month_format_for_chirps_date_window() -> None: normalized, _warnings = normalize_simple_request( WorkflowRequest.model_validate( @@ -488,3 +500,148 @@ def test_engine_rejects_unknown_workflow_id(monkeypatch: pytest.MonkeyPatch) -> engine.execute_workflow(request, workflow_id="not_allowlisted") assert exc_info.value.status_code == 422 + + +def test_engine_resolves_step_config_from_request_params(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + ds = xr.Dataset( + {"precip": (("time", "lat", "lon"), [[[1.0]]])}, + coords={"time": ["2024-01-01"], "lat": [0], "lon": [0]}, + ) + + monkeypatch.setattr( + engine, + "load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 2, + "steps": [ + {"component": "feature_source"}, + {"component": "download_dataset"}, + { + "component": "temporal_aggregation", + "config": { + "method": "$request.temporal_reducer", + "target_period_type": "$request.temporal_resolution", + }, + }, + {"component": "spatial_aggregation"}, + {"component": "build_datavalueset"}, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + monkeypatch.setattr( + engine.component_services, + 
"feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr(engine.component_services, "download_dataset_component", lambda **kwargs: None) + + def _temporal_component(**kwargs: Any) -> xr.Dataset: + assert kwargs["method"].value == "max" + assert kwargs["target_period_type"].value == "monthly" + return ds + + monkeypatch.setattr(engine.component_services, "temporal_aggregation_component", _temporal_component) + monkeypatch.setattr( + engine.component_services, + "spatial_aggregation_component", + lambda **kwargs: [{"org_unit": "OU_1", "time": "2024-01-01", "value": 10.0}], + ) + monkeypatch.setattr( + engine.component_services, + "build_datavalueset_component", + lambda **kwargs: ({"dataValues": [{"value": "10.0"}]}, "/tmp/data/out.json"), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + response = engine.execute_workflow( + request, + request_params={"temporal_reducer": "max", "temporal_resolution": "monthly"}, + ) + assert response.status == "completed" + + +def test_engine_rejects_invalid_step_config(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr( + engine, + "load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 2, + "steps": [ + {"component": "feature_source"}, + {"component": "download_dataset"}, + {"component": "temporal_aggregation", "config": {"invalid_key": 1}}, + {"component": 
"spatial_aggregation"}, + {"component": "build_datavalueset"}, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + persisted: dict[str, Any] = {} + + def _persist_run_log(**kwargs: Any) -> str: + persisted.update(kwargs) + return "/tmp/data/workflow_runs/run.json" + + monkeypatch.setattr(engine, "persist_run_log", _persist_run_log) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr(engine.component_services, "download_dataset_component", lambda **kwargs: None) + + with pytest.raises(HTTPException) as exc_info: + engine.execute_workflow(request) + + assert exc_info.value.status_code == 422 + detail = cast(dict[str, Any], exc_info.value.detail) + assert detail["error"] == "workflow_execution_failed" + assert detail["error_code"] == "CONFIG_VALIDATION_FAILED" + assert detail["failed_component"] == "temporal_aggregation" + assert detail["failed_component_version"] == "v1" + assert persisted["error_code"] == "CONFIG_VALIDATION_FAILED" + assert persisted["failed_component"] == "temporal_aggregation" + assert persisted["failed_component_version"] == "v1" From 6965a84ab0152d7b4119ec77f98743db2ae67b58 Mon Sep 17 00:00:00 2001 From: Abyot Asalefew Gizaw Date: Thu, 12 Mar 2026 13:36:46 +0100 Subject: [PATCH 5/5] feat: finalize workflow assembly validation and mixed local/remote orchestration --- docs/workflow-orchestration.md | 172 +++++-- src/eo_api/components/routes.py | 31 +- src/eo_api/components/schemas.py | 12 +- src/eo_api/components/services.py | 58 ++- src/eo_api/workflows/routes.py | 87 +++- src/eo_api/workflows/schemas.py | 54 +++ src/eo_api/workflows/services/engine.py | 494 +++++++++++++++++--- src/eo_api/workflows/services/spatial.py | 9 +- tests/test_workflows.py | 550 
++++++++++++++++++++++- 9 files changed, 1312 insertions(+), 155 deletions(-) diff --git a/docs/workflow-orchestration.md b/docs/workflow-orchestration.md index 97d26a1..279f0c8 100644 --- a/docs/workflow-orchestration.md +++ b/docs/workflow-orchestration.md @@ -20,7 +20,7 @@ The current implementation provides: 1. One canonical workflow execution endpoint: - `POST /workflows/dhis2-datavalue-set` -2. One public flat request payload contract (`WorkflowRequest`). +2. One public wrapped request payload contract (`{"request": WorkflowRequest}`). 3. Internal normalization into a canonical execution model (`WorkflowExecuteRequest`). 4. A fixed generic orchestration chain with exactly 5 components: - `feature_source` @@ -45,6 +45,8 @@ The current implementation provides: ### Primary Workflow Endpoint - `POST /workflows/dhis2-datavalue-set` +- `POST /workflows/execute` (inline assembly execution: post `workflow.steps` + `request` payload) +- `POST /workflows/validate` (validate discovered/inline workflow + request compatibility without execution) ### Workflow Discovery Endpoint @@ -52,7 +54,8 @@ The current implementation provides: ### Component Discovery/Execution Endpoints -- `GET /components` +- `GET /components` (public catalog; hides internal orchestration-only config schema) +- `GET /components?include_internal=true` (internal/debug catalog including component config schema) - `POST /components/feature-source` - `POST /components/download-dataset` - `POST /components/temporal-aggregation` @@ -65,20 +68,22 @@ The current implementation provides: ## Public Workflow Request Contract -The workflow endpoint accepts one flat payload shape: +The workflow endpoint accepts one wrapped payload shape: ```json { - "workflow_id": "dhis2_datavalue_set_v1", - "dataset_id": "chirps3_precipitation_daily", - "start_date": "2024-01-01", - "end_date": "2024-05-31", - "org_unit_level": 2, - "data_element": "DE_UID", - "temporal_resolution": "monthly", - "temporal_reducer": "sum", - 
"spatial_reducer": "mean", - "include_component_run_details": false + "request": { + "workflow_id": "dhis2_datavalue_set_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-05-31", + "org_unit_level": 2, + "data_element": "DE_UID", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "include_component_run_details": false + } } ``` @@ -106,7 +111,7 @@ Notes: File: `src/eo_api/workflows/services/simple_mapper.py` -Public flat payload is normalized to internal `WorkflowExecuteRequest` with component-ready nested configs: +Public wrapped payload (`request`) is normalized to internal `WorkflowExecuteRequest` with component-ready nested configs: 1. `feature_source` config: - `org_unit_level` -> `source_type=dhis2_level` @@ -251,6 +256,50 @@ Workflow step schema now supports: 2. `version` (default `v1`) 3. `config` (default `{}`) +### Remote Component Execution + +All five components support either local (default) or remote API execution. + +Common step config options: + +1. `execution_mode`: `local` or `remote` (default `local`) +2. `remote_url`: required when `execution_mode=remote` (expects component-compatible POST endpoint) +3. `remote_timeout_sec`: request timeout (default `30`) +4. `remote_retries`: number of attempts (default `1`) +5. `remote_retry_delay_sec`: delay between attempts in seconds (default `1`) +6. 
Component-specific options remain available (for example `overwrite`, `country_code`, `method`, `period_type`) + +Example: + +```yaml +steps: + - component: feature_source + version: v1 + config: + execution_mode: remote + remote_url: "http://component-host/components/feature-source" + - component: download_dataset + version: v1 + config: + execution_mode: remote + remote_url: "http://component-host/components/download-dataset" + - component: temporal_aggregation + version: v1 + config: + execution_mode: remote + remote_url: "http://component-host/components/temporal-aggregation" + - component: spatial_aggregation + version: v1 + config: + execution_mode: remote + remote_url: "http://component-host/components/spatial-aggregation" + - component: build_datavalueset + version: v1 + config: + execution_mode: remote + remote_url: "http://component-host/components/build-datavalue-set" +``` + The default YAML remains the same 5-step sequence, but the engine reads it declaratively and dispatches components through a registry map. 
--- @@ -446,16 +495,18 @@ curl -s http://127.0.0.1:8000/workflows | jq curl -s -X POST "http://127.0.0.1:8000/workflows/dhis2-datavalue-set" \ -H "Content-Type: application/json" \ -d '{ - "workflow_id": "dhis2_datavalue_set_v1", - "dataset_id": "chirps3_precipitation_daily", - "start_date": "2024-01-01", - "end_date": "2024-02-29", - "org_unit_level": 2, - "data_element": "DE_UID", - "temporal_resolution": "monthly", - "temporal_reducer": "sum", - "spatial_reducer": "mean", - "include_component_run_details": true + "request": { + "workflow_id": "dhis2_datavalue_set_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-02-29", + "org_unit_level": 2, + "data_element": "DE_UID", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "include_component_run_details": true + } }' | jq ``` @@ -473,14 +524,16 @@ Expected component order: curl -s -X POST "http://127.0.0.1:8000/workflows/dhis2-datavalue-set" \ -H "Content-Type: application/json" \ -d '{ - "workflow_id": "dhis2_datavalue_set_without_temporal_aggregation_v1", - "dataset_id": "chirps3_precipitation_daily", - "start_date": "2024-01-01", - "end_date": "2024-02-29", - "org_unit_level": 2, - "data_element": "DE_UID", - "spatial_reducer": "mean", - "include_component_run_details": true + "request": { + "workflow_id": "dhis2_datavalue_set_without_temporal_aggregation_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-02-29", + "org_unit_level": 2, + "data_element": "DE_UID", + "spatial_reducer": "mean", + "include_component_run_details": true + } }' | jq ``` @@ -497,19 +550,60 @@ Expected component order: curl -s -X POST "http://127.0.0.1:8000/workflows/dhis2-datavalue-set" \ -H "Content-Type: application/json" \ -d '{ - "workflow_id": "does_not_exist", - "dataset_id": "chirps3_precipitation_daily", - "start_date": "2024-01-01", - "end_date": "2024-01-31", - "org_unit_level": 2, 
- "data_element": "DE_UID" + "request": { + "workflow_id": "does_not_exist", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 2, + "data_element": "DE_UID" + } }' | jq ``` Expected result: `422` with allowed/discovered `workflow_id` values in error detail. +6. Validate inline assembly (no execution): + +```bash +curl -s -X POST "http://127.0.0.1:8000/workflows/validate" \ + -H "Content-Type: application/json" \ + -d '{ + "workflow": { + "workflow_id": "adhoc_validate_v1", + "version": 1, + "steps": [ + {"component": "feature_source", "version": "v1", "config": {}}, + {"component": "download_dataset", "version": "v1", "config": {}}, + {"component": "spatial_aggregation", "version": "v1", "config": {}}, + {"component": "build_datavalueset", "version": "v1", "config": {}} + ] + }, + "request": { + "workflow_id": "adhoc_validate_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 2, + "data_element": "DE_UID" + } + }' | jq +``` + +Expected result: `200` with `valid: true`, resolved step configs, and no execution side effects. + --- ## Next Technical Step -Add a workflow governance model for multi-user environments: workflow metadata (owner/status), promotion states (draft/staging/prod), and optional signature/checksum validation before a discovered YAML can execute. +Prioritize orchestration-tool readiness (Prefect/Airflow wrappers over the current workflow service) before any OGC-first migration. + +Rationale: + +1. Delivers immediate operational value (scheduling, retries, long-running reliability) with minimal API churn. +2. Reuses existing componentization, dispatcher, and run metadata. +3. Avoids a high-risk architecture pivot while the current workflow contract is stabilizing. 
+ +For detailed option synthesis and implementation scope, see: + +- `docs/internal/roadmap_v2.md` (Post-V2 Decision Synthesis) diff --git a/src/eo_api/components/routes.py b/src/eo_api/components/routes.py index 73519c2..3754535 100644 --- a/src/eo_api/components/routes.py +++ b/src/eo_api/components/routes.py @@ -2,7 +2,10 @@ from __future__ import annotations -from fastapi import APIRouter +from typing import Any + +import numpy as np +from fastapi import APIRouter, Query from ..data_manager.services.constants import BBOX from . import services @@ -23,10 +26,24 @@ router = APIRouter() -@router.get("/components", response_model=ComponentCatalogResponse) -def list_components() -> ComponentCatalogResponse: +def _to_jsonable_scalar(value: Any) -> Any: + """Convert numpy scalars/datetimes to JSON-safe native values.""" + if isinstance(value, np.datetime64): + return np.datetime_as_string(value, unit="s") + if isinstance(value, np.generic): + return value.item() + return value + + +def _json_safe_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Ensure record rows are JSON-serializable.""" + return [{key: _to_jsonable_scalar(value) for key, value in record.items()} for record in records] + + +@router.get("/components", response_model=ComponentCatalogResponse, response_model_exclude_none=True) +def list_components(include_internal: bool = Query(default=False)) -> ComponentCatalogResponse: """List all discoverable reusable components.""" - return ComponentCatalogResponse(components=services.component_catalog()) + return ComponentCatalogResponse(components=services.component_catalog(include_internal=include_internal)) @router.post("/components/feature-source", response_model=FeatureSourceRunResponse) @@ -94,10 +111,12 @@ def run_spatial_aggregation(payload: SpatialAggregationRunRequest) -> SpatialAgg method=payload.method, feature_id_property=payload.feature_id_property, ) + json_records = _json_safe_records(records) return 
SpatialAggregationRunResponse( dataset_id=payload.dataset_id, - record_count=len(records), - preview=records[: payload.max_preview_rows], + record_count=len(json_records), + preview=json_records[: payload.max_preview_rows], + records=json_records if payload.include_records else None, ) diff --git a/src/eo_api/components/schemas.py b/src/eo_api/components/schemas.py index a0920dc..1c2d684 100644 --- a/src/eo_api/components/schemas.py +++ b/src/eo_api/components/schemas.py @@ -14,6 +14,13 @@ ) +class ComponentEndpoint(BaseModel): + """HTTP endpoint metadata for a component.""" + + path: str + method: str + + class ComponentDefinition(BaseModel): """Component metadata for discovery.""" @@ -23,9 +30,10 @@ class ComponentDefinition(BaseModel): inputs: list[str] outputs: list[str] input_schema: dict[str, Any] = Field(default_factory=dict) - config_schema: dict[str, Any] = Field(default_factory=dict) + config_schema: dict[str, Any] | None = None output_schema: dict[str, Any] = Field(default_factory=dict) error_codes: list[str] = Field(default_factory=list) + endpoint: ComponentEndpoint class ComponentCatalogResponse(BaseModel): @@ -98,6 +106,7 @@ class SpatialAggregationRunRequest(BaseModel): method: AggregationMethod = AggregationMethod.MEAN bbox: list[float] | None = None feature_id_property: str = "id" + include_records: bool = False max_preview_rows: int = 20 @@ -107,6 +116,7 @@ class SpatialAggregationRunResponse(BaseModel): dataset_id: str record_count: int preview: list[dict[str, Any]] + records: list[dict[str, Any]] | None = None class BuildDataValueSetRunRequest(BaseModel): diff --git a/src/eo_api/components/services.py b/src/eo_api/components/services.py index 8bd0e1d..a546051 100644 --- a/src/eo_api/components/services.py +++ b/src/eo_api/components/services.py @@ -21,7 +21,7 @@ from ..workflows.services.preflight import check_upstream_connectivity from ..workflows.services.spatial import aggregate_to_features from ..workflows.services.temporal import 
aggregate_temporal -from .schemas import ComponentDefinition +from .schemas import ComponentDefinition, ComponentEndpoint _ERROR_CODES_V1: Final[list[str]] = [ "INPUT_VALIDATION_FAILED", @@ -43,7 +43,17 @@ "properties": {"feature_source": {"type": "object"}}, "required": ["feature_source"], }, - config_schema={"type": "object", "properties": {}, "additionalProperties": False}, + config_schema={ + "type": "object", + "properties": { + "execution_mode": {"type": "string", "enum": ["local", "remote"]}, + "remote_url": {"type": ["string", "null"]}, + "remote_timeout_sec": {"type": "number"}, + "remote_retries": {"type": "integer"}, + "remote_retry_delay_sec": {"type": "number"}, + }, + "additionalProperties": False, + }, output_schema={ "type": "object", "properties": { @@ -53,6 +63,7 @@ "required": ["features", "bbox"], }, error_codes=_ERROR_CODES_V1, + endpoint=ComponentEndpoint(path="/components/feature-source", method="POST"), ), "download_dataset@v1": ComponentDefinition( name="download_dataset", @@ -75,13 +86,17 @@ config_schema={ "type": "object", "properties": { - "overwrite": {"type": "boolean"}, - "country_code": {"type": ["string", "null"]}, + "execution_mode": {"type": "string", "enum": ["local", "remote"]}, + "remote_url": {"type": ["string", "null"]}, + "remote_timeout_sec": {"type": "number"}, + "remote_retries": {"type": "integer"}, + "remote_retry_delay_sec": {"type": "number"}, }, "additionalProperties": False, }, output_schema={"type": "object", "properties": {"status": {"type": "string"}}}, error_codes=_ERROR_CODES_V1, + endpoint=ComponentEndpoint(path="/components/download-dataset", method="POST"), ), "temporal_aggregation@v1": ComponentDefinition( name="temporal_aggregation", @@ -104,13 +119,17 @@ config_schema={ "type": "object", "properties": { - "target_period_type": {"type": "string"}, - "method": {"type": "string"}, + "execution_mode": {"type": "string", "enum": ["local", "remote"]}, + "remote_url": {"type": ["string", "null"]}, + 
"remote_timeout_sec": {"type": "number"}, + "remote_retries": {"type": "integer"}, + "remote_retry_delay_sec": {"type": "number"}, }, "additionalProperties": False, }, output_schema={"type": "object", "properties": {"dataset": {"type": "object"}}}, error_codes=_ERROR_CODES_V1, + endpoint=ComponentEndpoint(path="/components/temporal-aggregation", method="POST"), ), "spatial_aggregation@v1": ComponentDefinition( name="spatial_aggregation", @@ -132,13 +151,17 @@ config_schema={ "type": "object", "properties": { - "method": {"type": "string"}, - "feature_id_property": {"type": "string"}, + "execution_mode": {"type": "string", "enum": ["local", "remote"]}, + "remote_url": {"type": ["string", "null"]}, + "remote_timeout_sec": {"type": "number"}, + "remote_retries": {"type": "integer"}, + "remote_retry_delay_sec": {"type": "number"}, }, "additionalProperties": False, }, output_schema={"type": "object", "properties": {"records": {"type": "array"}}}, error_codes=_ERROR_CODES_V1, + endpoint=ComponentEndpoint(path="/components/spatial-aggregation", method="POST"), ), "build_datavalueset@v1": ComponentDefinition( name="build_datavalueset", @@ -159,7 +182,11 @@ config_schema={ "type": "object", "properties": { - "period_type": {"type": "string"}, + "execution_mode": {"type": "string", "enum": ["local", "remote"]}, + "remote_url": {"type": ["string", "null"]}, + "remote_timeout_sec": {"type": "number"}, + "remote_retries": {"type": "integer"}, + "remote_retry_delay_sec": {"type": "number"}, }, "additionalProperties": False, }, @@ -169,13 +196,20 @@ "required": ["data_value_set", "output_file"], }, error_codes=_ERROR_CODES_V1, + endpoint=ComponentEndpoint(path="/components/build-datavalue-set", method="POST"), ), } -def component_catalog() -> list[ComponentDefinition]: - """Return all discoverable component definitions.""" - return list(_COMPONENT_REGISTRY.values()) +def component_catalog(*, include_internal: bool = False) -> list[ComponentDefinition]: + """Return discoverable 
component definitions. + + By default, internal orchestration-only metadata (config_schema) is hidden. + """ + components = list(_COMPONENT_REGISTRY.values()) + if include_internal: + return components + return [component.model_copy(update={"config_schema": None}) for component in components] def component_registry() -> dict[str, ComponentDefinition]: diff --git a/src/eo_api/workflows/routes.py b/src/eo_api/workflows/routes.py index 6716c51..8650f54 100644 --- a/src/eo_api/workflows/routes.py +++ b/src/eo_api/workflows/routes.py @@ -2,9 +2,18 @@ from fastapi import APIRouter, HTTPException -from .schemas import WorkflowCatalogItem, WorkflowCatalogResponse, WorkflowExecuteResponse, WorkflowRequest -from .services.definitions import list_workflow_definitions -from .services.engine import execute_workflow +from .schemas import ( + WorkflowAssemblyExecuteRequest, + WorkflowCatalogItem, + WorkflowCatalogResponse, + WorkflowExecuteEnvelopeRequest, + WorkflowExecuteResponse, + WorkflowValidateRequest, + WorkflowValidateResponse, + WorkflowValidateStep, +) +from .services.definitions import list_workflow_definitions, load_workflow_definition +from .services.engine import execute_workflow, validate_workflow_steps from .services.simple_mapper import normalize_simple_request router = APIRouter() @@ -31,12 +40,74 @@ def list_workflows() -> WorkflowCatalogResponse: @router.post("/dhis2-datavalue-set", response_model=WorkflowExecuteResponse) -def run_dhis2_datavalue_set_workflow(payload: WorkflowRequest) -> WorkflowExecuteResponse: +def run_dhis2_datavalue_set_workflow(payload: WorkflowExecuteEnvelopeRequest) -> WorkflowExecuteResponse: """Run workflow from a single flat request payload.""" - request, _warnings = normalize_simple_request(payload) + request, _warnings = normalize_simple_request(payload.request) return execute_workflow( request, - workflow_id=payload.workflow_id, - request_params=payload.model_dump(), - 
include_component_run_details=payload.include_component_run_details, + workflow_id=payload.request.workflow_id, + request_params=payload.request.model_dump(), + include_component_run_details=payload.request.include_component_run_details, + ) + + +@router.post("/execute", response_model=WorkflowExecuteResponse) +def run_inline_assembled_workflow(payload: WorkflowAssemblyExecuteRequest) -> WorkflowExecuteResponse: + """Run an inline assembled workflow definition from one flat request payload.""" + request, _warnings = normalize_simple_request(payload.request) + return execute_workflow( + request, + workflow_id=payload.workflow.workflow_id, + workflow_definition=payload.workflow, + request_params=payload.request.model_dump(exclude_none=True), + include_component_run_details=payload.request.include_component_run_details, + ) + + +@router.post("/validate", response_model=WorkflowValidateResponse) +def validate_workflow_assembly(payload: WorkflowValidateRequest) -> WorkflowValidateResponse: + """Validate workflow assembly without executing any component.""" + warnings: list[str] = [] + errors: list[str] = [] + + try: + if payload.workflow is not None: + workflow = payload.workflow + else: + workflow = load_workflow_definition(payload.workflow_id or "") + except ValueError as exc: + return WorkflowValidateResponse( + valid=False, + workflow_id=payload.workflow_id or "unknown", + workflow_version=0, + step_count=0, + components=[], + warnings=warnings, + errors=[str(exc)], + ) + + request_params: dict[str, object] = {} + if payload.request is not None: + _request, map_warnings = normalize_simple_request(payload.request) + warnings.extend(map_warnings) + request_params = payload.request.model_dump(exclude_none=True) + + try: + resolved_steps = [ + WorkflowValidateStep.model_validate(step) + for step in validate_workflow_steps(workflow=workflow, request_params=request_params) + ] + except ValueError as exc: + errors.append(str(exc)) + resolved_steps = [] + + return 
WorkflowValidateResponse( + valid=not errors, + workflow_id=workflow.workflow_id, + workflow_version=workflow.version, + step_count=len(workflow.steps), + components=[step.component for step in workflow.steps], + resolved_steps=resolved_steps, + warnings=warnings, + errors=errors, ) diff --git a/src/eo_api/workflows/schemas.py b/src/eo_api/workflows/schemas.py index cfd0645..09072dc 100644 --- a/src/eo_api/workflows/schemas.py +++ b/src/eo_api/workflows/schemas.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, Field, model_validator +from .services.definitions import WorkflowDefinition + class FeatureSourceType(StrEnum): """Supported feature source backends.""" @@ -173,3 +175,55 @@ def validate_time_window(self) -> "WorkflowRequest": if self.org_unit_level is None and not self.org_unit_ids: raise ValueError("Provide org_unit_level or org_unit_ids") return self + + +class WorkflowExecuteEnvelopeRequest(BaseModel): + """Envelope for workflow execution input payload.""" + + request: WorkflowRequest + + +class WorkflowAssemblyExecuteRequest(BaseModel): + """Inline workflow assembly + wrapped public workflow input.""" + + request: WorkflowRequest + workflow: WorkflowDefinition + + +class WorkflowValidateRequest(BaseModel): + """Validation request for discovered or inline workflow assembly.""" + + workflow_id: str | None = None + workflow: WorkflowDefinition | None = None + request: WorkflowRequest | None = None + + @model_validator(mode="after") + def validate_workflow_source(self) -> "WorkflowValidateRequest": + """Require exactly one workflow source.""" + if (self.workflow_id is None and self.workflow is None) or ( + self.workflow_id is not None and self.workflow is not None + ): + raise ValueError("Provide exactly one of workflow_id or workflow") + return self + + +class WorkflowValidateStep(BaseModel): + """Resolved workflow step metadata from validation.""" + + index: int + component: str + version: str + resolved_config: dict[str, Any] + + +class 
WorkflowValidateResponse(BaseModel): + """Validation result for a workflow assembly.""" + + valid: bool + workflow_id: str + workflow_version: int + step_count: int + components: list[str] + resolved_steps: list[WorkflowValidateStep] = Field(default_factory=list) + warnings: list[str] = Field(default_factory=list) + errors: list[str] = Field(default_factory=list) diff --git a/src/eo_api/workflows/services/engine.py b/src/eo_api/workflows/services/engine.py index d8e1ad7..5f91716 100644 --- a/src/eo_api/workflows/services/engine.py +++ b/src/eo_api/workflows/services/engine.py @@ -2,15 +2,17 @@ from __future__ import annotations +import time from collections.abc import Callable from typing import Any +import httpx from fastapi import HTTPException from pydantic import BaseModel, ConfigDict, ValidationError from ...components import services as component_services from ...data_registry.services.datasets import get_dataset -from ..schemas import AggregationMethod, PeriodType, WorkflowExecuteRequest, WorkflowExecuteResponse +from ..schemas import WorkflowExecuteRequest, WorkflowExecuteResponse from .definitions import WorkflowDefinition, load_workflow_definition from .run_logs import persist_run_log from .runtime import WorkflowRuntime @@ -39,6 +41,7 @@ def execute_workflow( request: WorkflowExecuteRequest, *, workflow_id: str = "dhis2_datavalue_set_v1", + workflow_definition: WorkflowDefinition | None = None, request_params: dict[str, Any] | None = None, include_component_run_details: bool = False, ) -> WorkflowExecuteResponse: @@ -52,10 +55,13 @@ def execute_workflow( context: dict[str, Any] = {} try: - try: - workflow = load_workflow_definition(workflow_id) - except ValueError as exc: - raise HTTPException(status_code=422, detail=str(exc)) from exc + if workflow_definition is not None: + workflow = workflow_definition + else: + try: + workflow = load_workflow_definition(workflow_id) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) 
from exc _execute_workflow_steps( workflow=workflow, runtime=runtime, @@ -229,6 +235,31 @@ def _execute_workflow_steps( context.update(updates) +def validate_workflow_steps( + *, + workflow: WorkflowDefinition, + request_params: dict[str, Any] | None = None, +) -> list[dict[str, Any]]: + """Resolve and validate step configs without executing components.""" + resolved_steps: list[dict[str, Any]] = [] + params = request_params or {} + for index, step in enumerate(workflow.steps): + try: + resolved_config = _resolve_step_config(step.config, params) + _validate_step_config(step.component, step.version, resolved_config) + except ValueError as exc: + raise ValueError(f"Step {index + 1} ({step.component}@{step.version}) validation failed: {exc}") from exc + resolved_steps.append( + { + "index": index + 1, + "component": step.component, + "version": step.version, + "resolved_config": resolved_config, + } + ) + return resolved_steps + + type StepExecutor = Callable[..., dict[str, Any]] @@ -240,12 +271,24 @@ def _run_feature_source( context: dict[str, Any], step_config: dict[str, Any], ) -> dict[str, Any]: - del dataset, context, step_config - features, bbox = runtime.run( - "feature_source", - component_services.feature_source_component, - config=request.feature_source, - ) + del dataset, context + execution_mode = str(step_config.get("execution_mode", "local")).lower() + if execution_mode == "remote": + features, bbox = runtime.run( + "feature_source", + _invoke_remote_feature_source_component, + remote_url=str(step_config["remote_url"]), + feature_source=request.feature_source.model_dump(mode="json"), + timeout_sec=float(step_config.get("remote_timeout_sec", 30.0)), + retries=int(step_config.get("remote_retries", 1)), + retry_delay_sec=float(step_config.get("remote_retry_delay_sec", 1.0)), + ) + else: + features, bbox = runtime.run( + "feature_source", + component_services.feature_source_component, + config=request.feature_source, + ) return {"features": features, "bbox": 
bbox} @@ -257,18 +300,45 @@ def _run_download_dataset( context: dict[str, Any], step_config: dict[str, Any], ) -> dict[str, Any]: - overwrite = bool(step_config.get("overwrite", request.overwrite)) - country_code = step_config.get("country_code", request.country_code) - runtime.run( - "download_dataset", - component_services.download_dataset_component, - dataset=dataset, - start=request.start, - end=request.end, - overwrite=overwrite, - country_code=country_code, - bbox=_require_context(context, "bbox"), - ) + execution_mode = str(step_config.get("execution_mode", "local")).lower() + if execution_mode not in {"local", "remote"}: + raise ValueError("download_dataset.execution_mode must be 'local' or 'remote'") + + overwrite = request.overwrite + country_code = request.country_code + bbox = _require_context(context, "bbox") + if execution_mode == "remote": + remote_url = step_config.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + raise ValueError("download_dataset remote mode requires non-empty 'remote_url'") + remote_timeout = float(step_config.get("remote_timeout_sec", 30.0)) + remote_retries = int(step_config.get("remote_retries", 1)) + remote_retry_delay_sec = float(step_config.get("remote_retry_delay_sec", 1.0)) + runtime.run( + "download_dataset", + _invoke_remote_download_component, + remote_url=remote_url, + dataset_id=request.dataset_id, + start=request.start, + end=request.end, + overwrite=overwrite, + country_code=country_code, + bbox=bbox, + timeout_sec=remote_timeout, + retries=remote_retries, + retry_delay_sec=remote_retry_delay_sec, + ) + else: + runtime.run( + "download_dataset", + component_services.download_dataset_component, + dataset=dataset, + start=request.start, + end=request.end, + overwrite=overwrite, + country_code=country_code, + bbox=bbox, + ) return {} @@ -280,20 +350,35 @@ def _run_temporal_aggregation( context: dict[str, Any], step_config: dict[str, Any], ) -> dict[str, Any]: - target_period_type = 
PeriodType( - str(step_config.get("target_period_type", request.temporal_aggregation.target_period_type)) - ) - method = AggregationMethod(str(step_config.get("method", request.temporal_aggregation.method))) - temporal_ds = runtime.run( - "temporal_aggregation", - component_services.temporal_aggregation_component, - dataset=dataset, - start=request.start, - end=request.end, - bbox=_require_context(context, "bbox"), - target_period_type=target_period_type, - method=method, - ) + target_period_type = request.temporal_aggregation.target_period_type + method = request.temporal_aggregation.method + execution_mode = str(step_config.get("execution_mode", "local")).lower() + if execution_mode == "remote": + temporal_ds = runtime.run( + "temporal_aggregation", + _invoke_remote_temporal_aggregation_component, + remote_url=str(step_config["remote_url"]), + dataset_id=request.dataset_id, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + target_period_type=target_period_type.value, + method=method.value, + timeout_sec=float(step_config.get("remote_timeout_sec", 30.0)), + retries=int(step_config.get("remote_retries", 1)), + retry_delay_sec=float(step_config.get("remote_retry_delay_sec", 1.0)), + ) + else: + temporal_ds = runtime.run( + "temporal_aggregation", + component_services.temporal_aggregation_component, + dataset=dataset, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + target_period_type=target_period_type, + method=method, + ) return {"temporal_dataset": temporal_ds} @@ -305,19 +390,37 @@ def _run_spatial_aggregation( context: dict[str, Any], step_config: dict[str, Any], ) -> dict[str, Any]: - method = AggregationMethod(str(step_config.get("method", request.spatial_aggregation.method))) - feature_id_property = str(step_config.get("feature_id_property", request.dhis2.org_unit_property)) - records = runtime.run( - "spatial_aggregation", - component_services.spatial_aggregation_component, - 
dataset=dataset, - start=request.start, - end=request.end, - bbox=_require_context(context, "bbox"), - features=_require_context(context, "features"), - method=method, - feature_id_property=feature_id_property, - ) + method = request.spatial_aggregation.method + feature_id_property = request.dhis2.org_unit_property + execution_mode = str(step_config.get("execution_mode", "local")).lower() + if execution_mode == "remote": + records = runtime.run( + "spatial_aggregation", + _invoke_remote_spatial_aggregation_component, + remote_url=str(step_config["remote_url"]), + dataset_id=request.dataset_id, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + feature_source=request.feature_source.model_dump(mode="json"), + method=method.value, + feature_id_property=feature_id_property, + timeout_sec=float(step_config.get("remote_timeout_sec", 30.0)), + retries=int(step_config.get("remote_retries", 1)), + retry_delay_sec=float(step_config.get("remote_retry_delay_sec", 1.0)), + ) + else: + records = runtime.run( + "spatial_aggregation", + component_services.spatial_aggregation_component, + dataset=dataset, + start=request.start, + end=request.end, + bbox=_require_context(context, "bbox"), + features=_require_context(context, "features"), + method=method, + feature_id_property=feature_id_property, + ) return {"records": records} @@ -330,15 +433,30 @@ def _run_build_datavalueset( step_config: dict[str, Any], ) -> dict[str, Any]: del dataset - period_type = PeriodType(str(step_config.get("period_type", request.temporal_aggregation.target_period_type))) - data_value_set, output_file = runtime.run( - "build_datavalueset", - component_services.build_datavalueset_component, - records=_require_context(context, "records"), - dataset_id=request.dataset_id, - period_type=period_type, - dhis2=request.dhis2, - ) + period_type = request.temporal_aggregation.target_period_type + execution_mode = str(step_config.get("execution_mode", "local")).lower() + if 
execution_mode == "remote": + data_value_set, output_file = runtime.run( + "build_datavalueset", + _invoke_remote_build_datavalueset_component, + remote_url=str(step_config["remote_url"]), + dataset_id=request.dataset_id, + period_type=period_type.value, + records=_require_context(context, "records"), + dhis2=request.dhis2.model_dump(mode="json"), + timeout_sec=float(step_config.get("remote_timeout_sec", 30.0)), + retries=int(step_config.get("remote_retries", 1)), + retry_delay_sec=float(step_config.get("remote_retry_delay_sec", 1.0)), + ) + else: + data_value_set, output_file = runtime.run( + "build_datavalueset", + component_services.build_datavalueset_component, + records=_require_context(context, "records"), + dataset_id=request.dataset_id, + period_type=period_type, + dhis2=request.dhis2, + ) return {"data_value_set": data_value_set, "output_file": output_file} @@ -374,32 +492,51 @@ def _resolve_value(value: Any, request_params: dict[str, Any]) -> Any: class _FeatureSourceStepConfig(BaseModel): model_config = ConfigDict(extra="forbid") + execution_mode: str = "local" + remote_url: str | None = None + remote_timeout_sec: float = 30.0 + remote_retries: int = 1 + remote_retry_delay_sec: float = 1.0 + class _DownloadDatasetStepConfig(BaseModel): model_config = ConfigDict(extra="forbid") - overwrite: bool | None = None - country_code: str | None = None + execution_mode: str = "local" + remote_url: str | None = None + remote_timeout_sec: float = 30.0 + remote_retries: int = 1 + remote_retry_delay_sec: float = 1.0 class _TemporalAggregationStepConfig(BaseModel): model_config = ConfigDict(extra="forbid") - target_period_type: PeriodType | None = None - method: AggregationMethod | None = None + execution_mode: str = "local" + remote_url: str | None = None + remote_timeout_sec: float = 30.0 + remote_retries: int = 1 + remote_retry_delay_sec: float = 1.0 class _SpatialAggregationStepConfig(BaseModel): model_config = ConfigDict(extra="forbid") - method: AggregationMethod 
| None = None - feature_id_property: str | None = None + execution_mode: str = "local" + remote_url: str | None = None + remote_timeout_sec: float = 30.0 + remote_retries: int = 1 + remote_retry_delay_sec: float = 1.0 class _BuildDataValueSetStepConfig(BaseModel): model_config = ConfigDict(extra="forbid") - period_type: PeriodType | None = None + execution_mode: str = "local" + remote_url: str | None = None + remote_timeout_sec: float = 30.0 + remote_retries: int = 1 + remote_retry_delay_sec: float = 1.0 _STEP_CONFIG_MODELS: dict[str, type[BaseModel]] = { @@ -419,6 +556,229 @@ def _validate_step_config(component: str, version: str, config: dict[str, Any]) if model is None: raise ValueError(f"No config schema registered for component '{component}'") try: - model.model_validate(config) + validated = model.model_validate(config) except ValidationError as exc: raise ValueError(f"Invalid config for component '{component}@{version}': {exc}") from exc + mode = str(getattr(validated, "execution_mode", "local")).lower() + if mode not in {"local", "remote"}: + raise ValueError( + f"Invalid config for component '{component}@{version}': execution_mode must be local or remote" + ) + remote_url = getattr(validated, "remote_url", None) + remote_timeout_sec = getattr(validated, "remote_timeout_sec", 30.0) + remote_retries = getattr(validated, "remote_retries", 1) + remote_retry_delay_sec = getattr(validated, "remote_retry_delay_sec", 1.0) + + has_remote_config = bool( + (isinstance(remote_url, str) and remote_url.strip()) + or float(remote_timeout_sec) != 30.0 + or int(remote_retries) != 1 + or float(remote_retry_delay_sec) != 1.0 + ) + + if mode == "local" and has_remote_config: + raise ValueError( + f"Invalid config for component '{component}@{version}': " + "remote_url/remote_timeout_sec/remote_retries/remote_retry_delay_sec are only allowed in remote mode" + ) + if mode == "remote": + if not isinstance(remote_url, str) or not remote_url.strip(): + raise ValueError( + f"Invalid 
config for component '{component}@{version}': remote_url is required for remote mode" + ) + + +def _invoke_remote_download_component( + *, + remote_url: str, + dataset_id: str, + start: str, + end: str, + overwrite: bool, + country_code: str | None, + bbox: list[float], + timeout_sec: float, + retries: int, + retry_delay_sec: float, +) -> None: + """Invoke remote download component endpoint with retry/timeout.""" + payload = { + "dataset_id": dataset_id, + "start": start, + "end": end, + "overwrite": overwrite, + "country_code": country_code, + "bbox": bbox, + } + attempts = max(1, retries) + last_exc: Exception | None = None + for attempt in range(1, attempts + 1): + try: + with httpx.Client(timeout=timeout_sec) as client: + response = client.post(remote_url, json=payload) + response.raise_for_status() + return + except Exception as exc: + last_exc = exc + if attempt < attempts: + time.sleep(max(0.0, retry_delay_sec)) + if last_exc is None: + raise RuntimeError("Remote download invocation failed without exception context") + raise last_exc + + +def _invoke_remote_feature_source_component( + *, + remote_url: str, + feature_source: dict[str, Any], + timeout_sec: float, + retries: int, + retry_delay_sec: float, +) -> tuple[dict[str, Any], list[float]]: + """Invoke remote feature-source component endpoint.""" + payload = { + "feature_source": feature_source, + "include_features": True, + } + result = _post_remote_json( + remote_url=remote_url, + payload=payload, + timeout_sec=timeout_sec, + retries=retries, + retry_delay_sec=retry_delay_sec, + ) + features = result.get("features") + bbox = result.get("bbox") + if not isinstance(features, dict) or not isinstance(bbox, list): + raise RuntimeError("Remote feature_source response missing features/bbox") + return features, [float(x) for x in bbox] + + +def _invoke_remote_temporal_aggregation_component( + *, + remote_url: str, + dataset_id: str, + start: str, + end: str, + bbox: list[float], + target_period_type: str, + 
method: str, + timeout_sec: float, + retries: int, + retry_delay_sec: float, +) -> dict[str, Any]: + """Invoke remote temporal-aggregation component endpoint.""" + payload = { + "dataset_id": dataset_id, + "start": start, + "end": end, + "bbox": bbox, + "target_period_type": target_period_type, + "method": method, + } + return _post_remote_json( + remote_url=remote_url, + payload=payload, + timeout_sec=timeout_sec, + retries=retries, + retry_delay_sec=retry_delay_sec, + ) + + +def _invoke_remote_spatial_aggregation_component( + *, + remote_url: str, + dataset_id: str, + start: str, + end: str, + bbox: list[float], + feature_source: dict[str, Any], + method: str, + feature_id_property: str, + timeout_sec: float, + retries: int, + retry_delay_sec: float, +) -> list[dict[str, Any]]: + """Invoke remote spatial-aggregation component endpoint.""" + payload = { + "dataset_id": dataset_id, + "start": start, + "end": end, + "feature_source": feature_source, + "method": method, + "bbox": bbox, + "feature_id_property": feature_id_property, + "include_records": True, + } + result = _post_remote_json( + remote_url=remote_url, + payload=payload, + timeout_sec=timeout_sec, + retries=retries, + retry_delay_sec=retry_delay_sec, + ) + records = result.get("records") + if not isinstance(records, list): + raise RuntimeError("Remote spatial_aggregation response missing records") + return records + + +def _invoke_remote_build_datavalueset_component( + *, + remote_url: str, + dataset_id: str, + period_type: str, + records: list[dict[str, Any]], + dhis2: dict[str, Any], + timeout_sec: float, + retries: int, + retry_delay_sec: float, +) -> tuple[dict[str, Any], str]: + """Invoke remote build-datavalue-set component endpoint.""" + payload = { + "dataset_id": dataset_id, + "period_type": period_type, + "records": records, + "dhis2": dhis2, + } + result = _post_remote_json( + remote_url=remote_url, + payload=payload, + timeout_sec=timeout_sec, + retries=retries, + 
retry_delay_sec=retry_delay_sec, + ) + data_value_set = result.get("data_value_set") + output_file = result.get("output_file") + if not isinstance(data_value_set, dict) or not isinstance(output_file, str): + raise RuntimeError("Remote build_datavalueset response missing data_value_set/output_file") + return data_value_set, output_file + + +def _post_remote_json( + *, + remote_url: str, + payload: dict[str, Any], + timeout_sec: float, + retries: int, + retry_delay_sec: float, +) -> dict[str, Any]: + """POST JSON to remote component endpoint with retry and return JSON body.""" + attempts = max(1, retries) + last_exc: Exception | None = None + for attempt in range(1, attempts + 1): + try: + with httpx.Client(timeout=timeout_sec) as client: + response = client.post(remote_url, json=payload) + response.raise_for_status() + body = response.json() + if not isinstance(body, dict): + raise RuntimeError("Remote component returned non-object JSON response") + return body + except Exception as exc: + last_exc = exc + if attempt < attempts: + time.sleep(max(0.0, retry_delay_sec)) + if last_exc is None: + raise RuntimeError("Remote component invocation failed without exception context") + raise last_exc diff --git a/src/eo_api/workflows/services/spatial.py b/src/eo_api/workflows/services/spatial.py index 668141e..8bf3f05 100644 --- a/src/eo_api/workflows/services/spatial.py +++ b/src/eo_api/workflows/services/spatial.py @@ -46,10 +46,17 @@ def aggregate_to_features( for t, value in zip(reduced[time_dim].values, reduced.values, strict=True): if np.isnan(value): continue + # Keep component outputs JSON-safe for direct API exposure and remote execution. 
+ if isinstance(t, np.datetime64): + time_value: Any = np.datetime_as_string(t, unit="s") + elif isinstance(t, np.generic): + time_value = t.item() + else: + time_value = t output.append( { "org_unit": org_unit, - "time": t, + "time": time_value, "value": float(value), } ) diff --git a/tests/test_workflows.py b/tests/test_workflows.py index 4056629..27577db 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -2,6 +2,7 @@ from typing import Any, cast +import numpy as np import pytest import xarray as xr from fastapi import HTTPException @@ -17,17 +18,19 @@ def _valid_public_payload() -> dict[str, Any]: return { - "workflow_id": "dhis2_datavalue_set_v1", - "dataset_id": "chirps3_precipitation_daily", - "start_date": "2024-01-01", - "end_date": "2024-01-31", - "org_unit_level": 3, - "data_element": "abc123def45", - "temporal_resolution": "monthly", - "temporal_reducer": "sum", - "spatial_reducer": "mean", - "dry_run": True, - "include_component_run_details": False, + "request": { + "workflow_id": "dhis2_datavalue_set_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 3, + "data_element": "abc123def45", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "dry_run": True, + "include_component_run_details": False, + } } @@ -37,7 +40,7 @@ def test_workflow_endpoint_exists_once() -> None: for route in app.routes if isinstance(route, APIRoute) and route.path.startswith("/workflows") and "POST" in route.methods } - assert workflow_routes == {"/workflows/dhis2-datavalue-set"} + assert workflow_routes == {"/workflows/dhis2-datavalue-set", "/workflows/execute", "/workflows/validate"} def test_workflow_catalog_endpoint_returns_allowlisted_workflow(client: TestClient) -> None: @@ -85,9 +88,20 @@ def test_components_catalog_endpoint_returns_five_components(client: TestClient) for item in items: assert item["version"] == "v1" assert 
isinstance(item["input_schema"], dict) - assert isinstance(item["config_schema"], dict) + assert "config_schema" not in item assert isinstance(item["output_schema"], dict) assert "EXECUTION_FAILED" in item["error_codes"] + assert item["endpoint"]["method"] == "POST" + assert item["endpoint"]["path"].startswith("/components/") + + +def test_components_catalog_include_internal_includes_config_schema(client: TestClient) -> None: + response = client.get("/components?include_internal=true") + assert response.status_code == 200 + items = response.json()["components"] + assert len(items) >= 5 + for item in items: + assert isinstance(item["config_schema"], dict) def test_workflow_endpoint_returns_response_shape(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: @@ -146,7 +160,7 @@ def _execute_stub( def test_workflow_endpoint_validates_required_fields(client: TestClient) -> None: payload = _valid_public_payload() - payload.pop("org_unit_level") + payload["request"].pop("org_unit_level") response = client.post("/workflows/dhis2-datavalue-set", json=payload) assert response.status_code == 422 @@ -199,6 +213,205 @@ def _execute_stub( assert response.json()["status"] == "completed" +def test_inline_workflow_execute_endpoint_accepts_assembly(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + stub = WorkflowExecuteResponse( + status="completed", + run_id="run-assembly-123", + workflow_id="adhoc_dhis2_v1", + workflow_version=1, + dataset_id="chirps3_precipitation_daily", + bbox=[-13.3, 6.9, -10.1, 10.0], + feature_count=2, + value_count=4, + output_file="/tmp/data/chirps3_datavalueset.json", + run_log_file="/tmp/data/workflow_runs/run-assembly-123.json", + data_value_set={"dataValues": []}, + component_runs=[], + ) + + def _execute_stub( + payload: Any, + workflow_id: str = "dhis2_datavalue_set_v1", + workflow_definition: WorkflowDefinition | None = None, + request_params: dict[str, Any] | None = None, + include_component_run_details: bool = False, + ) -> 
WorkflowExecuteResponse: + del payload, request_params, include_component_run_details + assert workflow_id == "adhoc_dhis2_v1" + assert workflow_definition is not None + assert workflow_definition.workflow_id == "adhoc_dhis2_v1" + assert len(workflow_definition.steps) == 4 + return stub + + monkeypatch.setattr("eo_api.workflows.routes.execute_workflow", _execute_stub) + + response = client.post( + "/workflows/execute", + json={ + "workflow": { + "workflow_id": "adhoc_dhis2_v1", + "version": 1, + "steps": [ + {"component": "feature_source", "version": "v1", "config": {}}, + {"component": "download_dataset", "version": "v1", "config": {}}, + {"component": "spatial_aggregation", "version": "v1", "config": {}}, + {"component": "build_datavalueset", "version": "v1", "config": {}}, + ], + }, + "request": { + "workflow_id": "adhoc_dhis2_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 3, + "data_element": "abc123def45", + "temporal_resolution": "monthly", + "temporal_reducer": "sum", + "spatial_reducer": "mean", + "include_component_run_details": False, + }, + }, + ) + assert response.status_code == 200 + assert response.json()["workflow_id"] == "adhoc_dhis2_v1" + + +def test_inline_workflow_execute_endpoint_rejects_bad_component_chain(client: TestClient) -> None: + response = client.post( + "/workflows/execute", + json={ + "workflow": { + "workflow_id": "bad_adhoc_v1", + "version": 1, + "steps": [ + {"component": "download_dataset", "version": "v1", "config": {}}, + {"component": "build_datavalueset", "version": "v1", "config": {}}, + ], + }, + "request": { + "workflow_id": "bad_adhoc_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 3, + "data_element": "abc123def45", + }, + }, + ) + assert response.status_code == 422 + + +def test_workflow_validate_endpoint_accepts_valid_inline_workflow(client: TestClient) -> 
None: + response = client.post( + "/workflows/validate", + json={ + "workflow": { + "workflow_id": "adhoc_validate_v1", + "version": 1, + "steps": [ + {"component": "feature_source", "version": "v1", "config": {}}, + {"component": "download_dataset", "version": "v1", "config": {}}, + {"component": "spatial_aggregation", "version": "v1", "config": {}}, + {"component": "build_datavalueset", "version": "v1", "config": {}}, + ], + }, + "request": { + "workflow_id": "adhoc_validate_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 3, + "data_element": "abc123def45", + }, + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["valid"] is True + assert body["workflow_id"] == "adhoc_validate_v1" + assert body["step_count"] == 4 + assert len(body["resolved_steps"]) == 4 + assert body["errors"] == [] + + +def test_workflow_validate_endpoint_rejects_runtime_knobs_in_step_config(client: TestClient) -> None: + response = client.post( + "/workflows/validate", + json={ + "workflow": { + "workflow_id": "adhoc_invalid_config_v1", + "version": 1, + "steps": [ + {"component": "feature_source", "version": "v1", "config": {}}, + {"component": "download_dataset", "version": "v1", "config": {"overwrite": True}}, + {"component": "spatial_aggregation", "version": "v1", "config": {}}, + {"component": "build_datavalueset", "version": "v1", "config": {}}, + ], + }, + "request": { + "workflow_id": "adhoc_invalid_config_v1", + "dataset_id": "chirps3_precipitation_daily", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "org_unit_level": 3, + "data_element": "abc123def45", + }, + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["valid"] is False + assert body["resolved_steps"] == [] + assert len(body["errors"]) == 1 + assert "validation failed" in body["errors"][0].lower() + + +def test_workflow_validate_endpoint_unknown_workflow_id(client: 
TestClient) -> None: + response = client.post("/workflows/validate", json={"workflow_id": "does_not_exist"}) + assert response.status_code == 200 + body = response.json() + assert body["valid"] is False + assert body["step_count"] == 0 + assert len(body["errors"]) == 1 + assert "Unknown workflow_id" in body["errors"][0] + + +def test_component_spatial_aggregation_serializes_numpy_datetime64( + client: TestClient, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr( + "eo_api.components.routes.services.require_dataset", + lambda dataset_id: {"id": dataset_id, "variable": "precip"}, + ) + monkeypatch.setattr( + "eo_api.components.routes.services.feature_source_component", + lambda feature_source: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0.0, 0.0, 1.0, 1.0], + ), + ) + monkeypatch.setattr( + "eo_api.components.routes.services.spatial_aggregation_component", + lambda **kwargs: [{"org_unit": "OU_1", "time": np.datetime64("2024-01-01"), "value": 10.0}], + ) + + response = client.post( + "/components/spatial-aggregation", + json={ + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01", + "end": "2024-01", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 2}, + "method": "mean", + "include_records": True, + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["record_count"] == 1 + assert body["records"][0]["time"] == "2024-01-01T00:00:00" + + def test_engine_orchestrates_components(monkeypatch: pytest.MonkeyPatch) -> None: request = { "dataset_id": "chirps3_precipitation_daily", @@ -528,13 +741,13 @@ def test_engine_resolves_step_config_from_request_params(monkeypatch: pytest.Mon "version": 2, "steps": [ {"component": "feature_source"}, - {"component": "download_dataset"}, + { + "component": "download_dataset", + "config": {"execution_mode": "$request.download_execution_mode"}, + }, { "component": "temporal_aggregation", - "config": { - 
"method": "$request.temporal_reducer", - "target_period_type": "$request.temporal_resolution", - }, + "config": {}, }, {"component": "spatial_aggregation"}, {"component": "build_datavalueset"}, @@ -558,7 +771,7 @@ def test_engine_resolves_step_config_from_request_params(monkeypatch: pytest.Mon monkeypatch.setattr(engine.component_services, "download_dataset_component", lambda **kwargs: None) def _temporal_component(**kwargs: Any) -> xr.Dataset: - assert kwargs["method"].value == "max" + assert kwargs["method"].value == "sum" assert kwargs["target_period_type"].value == "monthly" return ds @@ -577,7 +790,7 @@ def _temporal_component(**kwargs: Any) -> xr.Dataset: response = engine.execute_workflow( request, - request_params={"temporal_reducer": "max", "temporal_resolution": "monthly"}, + request_params={"download_execution_mode": "local"}, ) assert response.status == "completed" @@ -645,3 +858,298 @@ def _persist_run_log(**kwargs: Any) -> str: assert persisted["error_code"] == "CONFIG_VALIDATION_FAILED" assert persisted["failed_component"] == "temporal_aggregation" assert persisted["failed_component_version"] == "v1" + + +def test_engine_download_dataset_remote_mode_uses_remote_adapter(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr( + engine, + "load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 1, + "steps": [ + {"component": "feature_source"}, + { + "component": "download_dataset", + "config": { + "execution_mode": "remote", + "remote_url": 
"http://component-host/components/download-dataset", + "remote_retries": 2, + "remote_timeout_sec": 9, + }, + }, + {"component": "spatial_aggregation"}, + {"component": "build_datavalueset"}, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + remote_called: dict[str, Any] = {} + + def _remote_adapter(**kwargs: Any) -> None: + remote_called.update(kwargs) + + monkeypatch.setattr(engine, "_invoke_remote_download_component", _remote_adapter) + monkeypatch.setattr( + engine.component_services, + "spatial_aggregation_component", + lambda **kwargs: [{"org_unit": "OU_1", "time": "2024-01-01", "value": 10.0}], + ) + monkeypatch.setattr( + engine.component_services, + "build_datavalueset_component", + lambda **kwargs: ({"dataValues": [{"value": "10.0"}]}, "/tmp/data/out.json"), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + response = engine.execute_workflow(request) + assert response.status == "completed" + assert remote_called["remote_url"] == "http://component-host/components/download-dataset" + assert remote_called["dataset_id"] == "chirps3_precipitation_daily" + + +def test_engine_rejects_remote_download_without_remote_url(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr( + engine, + 
"load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 1, + "steps": [ + {"component": "feature_source"}, + {"component": "download_dataset", "config": {"execution_mode": "remote"}}, + {"component": "spatial_aggregation"}, + {"component": "build_datavalueset"}, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + with pytest.raises(HTTPException) as exc_info: + engine.execute_workflow(request) + + assert exc_info.value.status_code == 422 + detail = cast(dict[str, Any], exc_info.value.detail) + assert detail["error_code"] == "CONFIG_VALIDATION_FAILED" + assert detail["failed_component"] == "download_dataset" + + +def test_engine_rejects_remote_fields_in_local_mode(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr( + engine, + "load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 1, + "steps": [ + {"component": "feature_source"}, + { + "component": "download_dataset", + "config": { + "execution_mode": "local", + "remote_url": "http://should-not-be-here/components/download-dataset", + }, + }, 
+ {"component": "spatial_aggregation"}, + {"component": "build_datavalueset"}, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + monkeypatch.setattr( + engine.component_services, + "feature_source_component", + lambda config: ( + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + ), + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + with pytest.raises(HTTPException) as exc_info: + engine.execute_workflow(request) + + assert exc_info.value.status_code == 422 + detail = cast(dict[str, Any], exc_info.value.detail) + assert detail["error_code"] == "CONFIG_VALIDATION_FAILED" + assert detail["failed_component"] == "download_dataset" + + +def test_engine_supports_remote_mode_for_all_components(monkeypatch: pytest.MonkeyPatch) -> None: + request = WorkflowExecuteRequest.model_validate( + { + "dataset_id": "chirps3_precipitation_daily", + "start": "2024-01-01", + "end": "2024-01-31", + "feature_source": {"source_type": "dhis2_level", "dhis2_level": 3}, + "temporal_aggregation": {"target_period_type": "monthly", "method": "sum"}, + "spatial_aggregation": {"method": "mean"}, + "dhis2": {"data_element_uid": "abc123def45"}, + } + ) + monkeypatch.setattr( + engine, + "load_workflow_definition", + lambda workflow_id: WorkflowDefinition.model_validate( + { + "workflow_id": workflow_id, + "version": 1, + "steps": [ + { + "component": "feature_source", + "config": {"execution_mode": "remote", "remote_url": "http://x/components/feature-source"}, + }, + { + "component": "download_dataset", + "config": { + "execution_mode": "remote", + "remote_url": "http://x/components/download-dataset", + }, + }, + { + "component": "temporal_aggregation", + "config": { + "execution_mode": "remote", + "remote_url": "http://x/components/temporal-aggregation", + }, + }, + { + 
"component": "spatial_aggregation", + "config": { + "execution_mode": "remote", + "remote_url": "http://x/components/spatial-aggregation", + }, + }, + { + "component": "build_datavalueset", + "config": { + "execution_mode": "remote", + "remote_url": "http://x/components/build-datavalue-set", + }, + }, + ], + } + ), + ) + monkeypatch.setattr( + engine, + "get_dataset", + lambda dataset_id: {"id": "chirps3_precipitation_daily", "variable": "precip"}, + ) + + called: dict[str, bool] = { + "feature": False, + "download": False, + "temporal": False, + "spatial": False, + "build": False, + } + + monkeypatch.setattr( + engine, + "_invoke_remote_feature_source_component", + lambda **kwargs: ( + called.__setitem__("feature", True), + {"type": "FeatureCollection", "features": [{"id": "OU_1", "properties": {"id": "OU_1"}}]}, + [0, 0, 1, 1], + )[1:], + ) + monkeypatch.setattr( + engine, + "_invoke_remote_download_component", + lambda **kwargs: called.__setitem__("download", True), + ) + monkeypatch.setattr( + engine, + "_invoke_remote_temporal_aggregation_component", + lambda **kwargs: (called.__setitem__("temporal", True), {"sizes": {"time": 1}, "dims": ["time"]})[1], + ) + monkeypatch.setattr( + engine, + "_invoke_remote_spatial_aggregation_component", + lambda **kwargs: ( + called.__setitem__("spatial", True), + [{"org_unit": "OU_1", "time": "2024-01-01", "value": 10.0}], + )[1], + ) + monkeypatch.setattr( + engine, + "_invoke_remote_build_datavalueset_component", + lambda **kwargs: ( + called.__setitem__("build", True), + ({"dataValues": [{"value": "10.0"}]}, "/tmp/data/out.json"), + )[1], + ) + monkeypatch.setattr(engine, "persist_run_log", lambda **kwargs: "/tmp/data/workflow_runs/run.json") + + response = engine.execute_workflow(request) + assert response.status == "completed" + assert all(called.values())