Skip to content
This repository has been archived by the owner on May 2, 2024. It is now read-only.

Commit

Permalink
Config dask (#86)
Browse files Browse the repository at this point in the history
## Purpose
use the config to set the property "enable_dask" to avoid having to pass it as args to various functions and avoid inconsistencies.

## Code Changes

* Get the property enable_dask via the config.
* Hide dask behind a tasking module, so that dask is not a dependency of any of the products. This way we adopt a task parallel model that can be configure with any tasking runtime system.
  • Loading branch information
cosunae authored Oct 16, 2023
1 parent 0f4335f commit dc8cfe0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/idpi/grib_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
import sys
import typing
from contextlib import contextmanager
from functools import partial
from pathlib import Path

# Third-party
import dask
import earthkit.data # type: ignore
import eccodes # type: ignore
import numpy as np
import xarray as xr

# First-party
import idpi.config
from idpi.product import ProductDescriptor

# Local
from . import tasking
from .product import ProductDescriptor

DIM_MAP = {
"level": "z",
Expand Down Expand Up @@ -147,7 +148,6 @@ def __init__(
"""
self._datafiles = [str(p) for p in datafiles]
self._delayed = partial(dask.delayed, pure=True) if delay else (lambda x: x)
if idpi.config.get("data_scope", "cosmo") == "cosmo":
with cosmo_grib_defs():
self._grid = self.load_grid_reference(ref_param)
Expand Down Expand Up @@ -295,7 +295,7 @@ def _load_dataset(
result = {}

for param in _params:
result[param] = self._delayed(self._load_param)(param) # type: ignore
result[param] = tasking.delayed(self._load_param)(param)

if not _params == result.keys():
raise RuntimeError(f"Missing params: {_params - data.keys()}")
Expand Down
7 changes: 3 additions & 4 deletions src/idpi/product.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Product base classes."""

# Standard library
import dataclasses as dc
from abc import ABCMeta, abstractmethod

# Third-party
import dask
# Local
from . import tasking


@dc.dataclass
Expand All @@ -30,7 +29,7 @@ def _run(self, **args):

def __call__(self, *args):
if self._delay_entire_product:
return dask.delayed(self._run, pure=True)(*args)
return tasking.delayed(self._run)(*args)
else:
return self._run(*args)

Expand Down
18 changes: 18 additions & 0 deletions src/idpi/tasking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""functionality for tasking and parallel computing."""
# Third-party
import dask

# First-party
import idpi.config


def delayed(fn):
return dask.delayed(fn, pure=True) if idpi.config.get("enable_dask", False) else fn


def compute(*delayed_objs):
return (
dask.compute(*delayed_objs)
if idpi.config.get("enable_dask", False)
else delayed_objs
)

0 comments on commit dc8cfe0

Please sign in to comment.