Skip to content

Commit

Permalink
Integrate generalized process filtering in workflow tests
Browse files Browse the repository at this point in the history
- integrate generalized process selection in Skipper
- port workflow tests
- eliminate deprecated `process_levels` fixture
  • Loading branch information
soxofaan committed Jan 25, 2024
1 parent a171cda commit 32296eb
Show file tree
Hide file tree
Showing 20 changed files with 156 additions and 65 deletions.
44 changes: 44 additions & 0 deletions src/openeo_test_suite/lib/internal-tests/test_skipping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import openeo
import pytest
from openeo import DataCube

from openeo_test_suite.lib.skipping import extract_processes_from_process_graph


def test_extract_processes_from_process_graph_basic():
    """A flat graph with a single node yields exactly that node's process id."""
    graph = {
        "add35": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}
    }
    assert extract_processes_from_process_graph(graph) == {"add"}


@pytest.fixture
def s2_cube() -> openeo.DataCube:
    """Client-side S2 cube (two bands, no real connection, metadata fetch skipped)."""
    cube = openeo.DataCube.load_collection(
        collection_id="S2",
        bands=["B02", "B03"],
        connection=None,
        fetch_metadata=False,
    )
    return cube


def test_extract_processes_from_process_graph_cube_simple(s2_cube):
    """A bare load_collection cube contains only that single process."""
    extracted = extract_processes_from_process_graph(s2_cube)
    assert extracted == {"load_collection"}


def test_extract_processes_from_process_graph_cube_reduce_temporal(s2_cube):
    """Temporal reduction surfaces both the reducer wrapper and its callback process."""
    reduced = s2_cube.reduce_temporal("mean")
    expected = {"load_collection", "reduce_dimension", "mean"}
    assert extract_processes_from_process_graph(reduced) == expected


def test_extract_processes_from_process_graph_cube_reduce_bands(s2_cube):
    """NDVI-style band math surfaces every process used in the band-reducer callback."""
    blue = s2_cube.band("B02")
    green = s2_cube.band("B03")
    # (B03 - B02) / (B03 + B02): a normalized-difference index.
    index = (green - blue) / (green + blue)
    expected = {
        "load_collection",
        "reduce_dimension",
        "array_element",
        "subtract",
        "add",
        "divide",
    }
    assert extract_processes_from_process_graph(index) == expected
85 changes: 68 additions & 17 deletions src/openeo_test_suite/lib/skipping.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from typing import Iterable, List, Union
from typing import Iterable, Iterator, List, Set, Union

import openeo
import pytest
from openeo.internal.graph_building import FlatGraphableMixin, as_flat_graph

_log = logging.getLogger(__name__)

Expand All @@ -13,9 +14,16 @@ class Skipper:
backend capabilities and configuration.
"""

def __init__(self, connection: openeo.Connection, process_levels: List[str]):
def __init__(
    self, connection: openeo.Connection, selected_processes: Iterable[str]
):
    """
    :param connection: openeo connection
    :param selected_processes: active process selection (iterable of process ids)
    """
    self._connection = connection
    # Normalize to a set for fast membership checks in the skip helpers.
    self._selected_processes = set(selected_processes)

def _get_output_formats(self) -> set:
formats = set(
Expand All @@ -34,24 +42,67 @@ def skip_if_no_geotiff_support(self):
if not output_formats.intersection({"geotiff", "gtiff"}):
pytest.skip("GeoTIFF not supported as output file format")

def skip_if_unmatching_process_level(self, level: str):
    """Skip test if "process_levels" are set and do not match the given level."""
    # An empty selection means "no filtering": never skip in that case.
    if self._process_levels and level not in self._process_levels:
        pytest.skip(
            f"Skipping {level} test because the specified levels are: {self._process_levels}"
        )
def _get_processes(
    self, processes: Union[str, List[str], Set[str], openeo.DataCube]
) -> Set[str]:
    """
    Generic process id extraction from:
    - string (single process id)
    - list/set/tuple of process ids
    - `openeo.DataCube` (or anything else that can produce a flat process graph,
      e.g. a VectorCube): extract process ids from its process graph

    :raises ValueError: when no process ids can be extracted from the argument.
    """
    if isinstance(processes, str):
        return {processes}
    elif isinstance(processes, (list, set, tuple, frozenset)):
        return set(processes)
    elif isinstance(processes, FlatGraphableMixin):
        # Wider than just DataCube (addresses earlier TODO): any object that can
        # be flattened to a process graph (DataCube, VectorCube, ...) works here.
        return extract_processes_from_process_graph(processes)
    else:
        raise ValueError(f"Can not extract process ids from {processes!r}")

def skip_if_unselected_process(
    self, processes: Union[str, List[str], Set[str], openeo.DataCube]
):
    """
    Skip test if any of the provided processes is not in the active process selection.

    :param processes: single process id, list/set of process ids or an `openeo.DataCube` to extract process ids from
    """
    # TODO: automatically call this skipper from monkey-patched `cube.download()`?
    unselected_processes = self._get_processes(processes) - self._selected_processes
    if unselected_processes:
        pytest.skip(f"Process selection does not cover: {unselected_processes}")

def skip_if_unsupported_process(self, processes: Union[str, Iterable[str]]):
def skip_if_unsupported_process(
    self, processes: Union[str, List[str], Set[str], openeo.DataCube]
):
    """
    Skip test if any of the provided processes is not supported by the backend.

    :param processes: single process id, list/set of process ids or an `openeo.DataCube` to extract process ids from
    """
    processes = self._get_processes(processes)
    unsupported_processes = processes.difference(self._get_available_processes())
    if unsupported_processes:
        pytest.skip(f"Backend does not support: {unsupported_processes}")

def _get_available_processes(self) -> Set[str]:
    """Process ids supported by the backend (cached after the first request)."""
    # Lazy per-instance cache (addresses earlier TODO): avoids re-fetching the
    # backend's process listing on every skip check within a test session.
    if not hasattr(self, "_available_processes"):
        self._available_processes = {
            p["id"] for p in self._connection.list_processes()
        }
    return self._available_processes


def extract_processes_from_process_graph(
    pg: Union[dict, FlatGraphableMixin]
) -> Set[str]:
    """
    Extract all process ids used in the given process graph,
    including processes used in child process graphs (e.g. reducer callbacks).

    :param pg: flat process graph representation (dict),
        or an object that can be converted to one (e.g. an `openeo.DataCube`).
    :return: set of process ids
    """
    pg = as_flat_graph(pg)

    def extract(pg: dict) -> Iterator[str]:
        for node in pg.values():
            yield node["process_id"]
            # Tolerate nodes without "arguments" (optional in a process graph node).
            for arg in node.get("arguments", {}).values():
                yield from extract_from_argument(arg)

    def extract_from_argument(arg) -> Iterator[str]:
        # Child process graphs appear as {"process_graph": {...}} argument values;
        # also recurse into list-valued arguments that may contain such callbacks.
        if isinstance(arg, dict):
            if "process_graph" in arg:
                yield from extract(arg["process_graph"])
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                yield from extract_from_argument(item)

    return set(extract(pg))
27 changes: 9 additions & 18 deletions src/openeo_test_suite/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,12 @@
import pytest

from openeo_test_suite.lib.backend_under_test import get_backend_url
from openeo_test_suite.lib.process_selection import get_selected_processes
from openeo_test_suite.lib.skipping import Skipper

_log = logging.getLogger(__name__)


@pytest.fixture(scope="session")
def process_levels(request) -> List[str]:
    """
    Fixture to get the desired openEO profile levels
    (from the `--process-levels` command line option, e.g. "L1,L2").
    """
    # TODO: eliminate this fixture?
    levels_str = request.config.getoption("--process-levels")

    if isinstance(levels_str, str) and levels_str:
        _log.info(f"Testing process levels {levels_str!r}")
        # Comma-separated option value; strip whitespace around each level.
        return [level.strip() for level in levels_str.split(",")]
    else:
        return []


@pytest.fixture(scope="module")
def auto_authenticate() -> bool:
"""
Expand All @@ -38,7 +24,9 @@ def auto_authenticate() -> bool:


@pytest.fixture(scope="module")
def connection(request, auto_authenticate: bool, pytestconfig) -> openeo.Connection:
def connection(
request, auto_authenticate: bool, pytestconfig: pytest.Config
) -> openeo.Connection:
backend_url = get_backend_url(request.config, required=True)
con = openeo.connect(backend_url, auto_validate=False)

Expand Down Expand Up @@ -68,5 +56,8 @@ def connection(request, auto_authenticate: bool, pytestconfig) -> openeo.Connect


@pytest.fixture
def skipper(connection) -> Skipper:
    """Skipper helper bound to the backend connection and the active process selection."""
    selected = [p.process_id for p in get_selected_processes()]
    return Skipper(connection=connection, selected_processes=selected)
2 changes: 1 addition & 1 deletion src/openeo_test_suite/tests/workflows/L1/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def test_apply(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_apply.nc"
cube = cube_one_day_red.apply(lambda x: x.clip(0, 1))
skipper.skip_if_unselected_process(cube)
cube.download(filename)

assert filename.exists()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def test_apply_dimension_quantiles_0(
number of labels will be equal to the number of values computed by the process.
"""
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_apply_dimension_quantiles_0.nc"
b_dim = collection_dims["b_dim"]
Expand All @@ -33,7 +32,7 @@ def test_apply_dimension_quantiles_0(
process=lambda d: quantiles(d, probabilities=[0.5, 0.75]),
dimension=b_dim,
)

skipper.skip_if_unselected_process(cube)
cube.download(filename)
assert filename.exists()
data = load_netcdf_dataarray(filename, band_dim_name=b_dim)
Expand All @@ -54,7 +53,6 @@ def test_apply_dimension_quantiles_1(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_apply_dimension_quantiles_1.nc"
b_dim = collection_dims["b_dim"]
Expand All @@ -67,7 +65,7 @@ def test_apply_dimension_quantiles_1(
dimension=t_dim,
target_dimension=b_dim,
)
print(filename)
skipper.skip_if_unselected_process(cube)
cube.download(filename)
assert filename.exists()
data = load_netcdf_dataarray(filename, band_dim_name=b_dim)
Expand All @@ -90,7 +88,6 @@ def test_apply_dimension_ndvi(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_apply_dimension_ndvi.nc"
b_dim = collection_dims["b_dim"]
Expand All @@ -104,6 +101,7 @@ def compute_ndvi(data):
return array_concat(data, ndvi)

ndvi = cube_one_day_red_nir.apply_dimension(dimension=b_dim, process=compute_ndvi)
skipper.skip_if_unselected_process(ndvi)
ndvi.download(filename)

assert filename.exists()
Expand Down
4 changes: 3 additions & 1 deletion src/openeo_test_suite/tests/workflows/L1/test_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def test_load_save_netcdf(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_load_save_netcdf.nc"
b_dim = collection_dims["b_dim"]
x_dim = collection_dims["x_dim"]
y_dim = collection_dims["y_dim"]
t_dim = collection_dims["t_dim"]

skipper.skip_if_unselected_process(cube_red_nir)
cube_red_nir.download(filename)

assert filename.exists()
Expand Down Expand Up @@ -64,6 +64,7 @@ def test_load_save_10x10_netcdf(
y_dim = collection_dims["y_dim"]
t_dim = collection_dims["t_dim"]

skipper.skip_if_unselected_process(cube_red_10x10)
cube_red_10x10.download(filename)

assert filename.exists()
Expand Down Expand Up @@ -105,6 +106,7 @@ def test_load_save_geotiff(
y_dim = collection_dims["y_dim"]
t_dim = collection_dims["t_dim"]

skipper.skip_if_unselected_process(cube_one_day_red)
cube_one_day_red.download(filename)

assert filename.exists()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def test_ndvi_index(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_ndvi_index.nc"
b_dim = collection_dims["b_dim"]
Expand All @@ -23,6 +22,7 @@ def compute_ndvi(data):
return (nir - red) / (nir + red)

ndvi = cube_one_day_red_nir.reduce_dimension(dimension=b_dim, reducer=compute_ndvi)
skipper.skip_if_unselected_process(ndvi)
ndvi.download(filename)

assert filename.exists()
Expand All @@ -49,6 +49,7 @@ def compute_ndvi(data):
return (nir - red) / (nir + red)

ndvi = cube_one_day_red_nir.reduce_dimension(dimension=b_dim, reducer=compute_ndvi)
skipper.skip_if_unselected_process(ndvi)
ndvi.download(filename)

assert filename.exists()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def test_boolean_mask(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_boolean_mask.nc"
b_dim = collection_dims["b_dim"]
Expand All @@ -26,6 +25,7 @@ def compute_ndvi(data):

ndvi = cube_one_day_red_nir.reduce_dimension(dimension=b_dim, reducer=compute_ndvi)
ndvi_mask = ndvi > 0.75
skipper.skip_if_unselected_process(ndvi_mask)
ndvi_mask.download(filename)

assert filename.exists()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def test_reduce_time(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_reduce_time.nc"
b_dim = collection_dims["b_dim"]
t_dim = collection_dims["t_dim"]

cube = cube_red_nir.reduce_dimension(dimension=t_dim, reducer="mean")
skipper.skip_if_unselected_process(cube)
cube.download(filename)

assert filename.exists()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def test_ndvi_add_dim(
tmp_path,
):
skipper.skip_if_no_netcdf_support()
skipper.skip_if_unmatching_process_level(level=LEVEL)

filename = tmp_path / "test_ndvi_add_dim.nc"
b_dim = collection_dims["b_dim"]
Expand All @@ -22,6 +21,7 @@ def compute_ndvi(data):

ndvi = cube_one_day_red_nir.reduce_dimension(dimension=b_dim, reducer=compute_ndvi)
ndvi = ndvi.add_dimension(type="bands", name=b_dim, label="NDVI")
skipper.skip_if_unselected_process(ndvi)
ndvi.download(filename)

assert filename.exists()
Expand Down
Loading

0 comments on commit 32296eb

Please sign in to comment.