diff --git a/src/idpi/grib_decoder.py b/src/idpi/grib_decoder.py index 135dae0d..b17f5538 100644 --- a/src/idpi/grib_decoder.py +++ b/src/idpi/grib_decoder.py @@ -5,11 +5,9 @@ import sys import typing from contextlib import contextmanager -from functools import partial from pathlib import Path # Third-party -import dask import earthkit.data # type: ignore import eccodes # type: ignore import numpy as np @@ -17,7 +15,10 @@ # First-party import idpi.config -from idpi.product import ProductDescriptor + +# Local +from . import tasking +from .product import ProductDescriptor DIM_MAP = { "level": "z", @@ -147,7 +148,6 @@ def __init__( """ self._datafiles = [str(p) for p in datafiles] - self._delayed = partial(dask.delayed, pure=True) if delay else (lambda x: x) if idpi.config.get("data_scope", "cosmo") == "cosmo": with cosmo_grib_defs(): self._grid = self.load_grid_reference(ref_param) @@ -295,7 +295,7 @@ def _load_dataset( result = {} for param in _params: - result[param] = self._delayed(self._load_param)(param) # type: ignore + result[param] = tasking.delayed(self._load_param)(param) if not _params == result.keys(): raise RuntimeError(f"Missing params: {_params - data.keys()}") diff --git a/src/idpi/product.py b/src/idpi/product.py index e9721e74..c1e0297e 100644 --- a/src/idpi/product.py +++ b/src/idpi/product.py @@ -1,11 +1,10 @@ """Product base classes.""" - # Standard library import dataclasses as dc from abc import ABCMeta, abstractmethod -# Third-party -import dask +# Local +from . import tasking @dc.dataclass @@ -30,7 +29,7 @@ def _run(self, **args): def __call__(self, *args): if self._delay_entire_product: - return dask.delayed(self._run, pure=True)(*args) + return tasking.delayed(self._run)(*args) else: return self._run(*args) diff --git a/src/idpi/tasking.py b/src/idpi/tasking.py new file mode 100644 index 00000000..82172732 --- /dev/null +++ b/src/idpi/tasking.py @@ -0,0 +1,18 @@ +"""functionality for tasking and parallel computing.""" +# Third-party +import dask + +# First-party +import idpi.config + + +def delayed(fn): + return dask.delayed(fn, pure=True) if idpi.config.get("enable_dask", False) else fn + + +def compute(*delayed_objs): + return ( + dask.compute(*delayed_objs) + if idpi.config.get("enable_dask", False) + else delayed_objs + )