diff --git a/doc/api.rst b/doc/api.rst index 63733318..091e0c30 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -141,6 +141,8 @@ Process introspection and variables variable_info filter_variables + + Process runtime methods ----------------------- @@ -156,6 +158,7 @@ Variable :toctree: _api_generated/ variable + MAIN_CLOCK index any_object foreign diff --git a/doc/whats_new.rst b/doc/whats_new.rst index fe0061b9..e0781cac 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -9,11 +9,13 @@ v0.6.0 (Unreleased) ``main_clock``, ``main_clock_dim`` and ``main_clock_coords`` and all occurences of ``master`` to ``main`` in the rest of the codebase. all ``master...`` API hooks are still working, but raise a Futurewarning + - Added access to main clock in initialize step as ``main_clock_values`` + and as a ``xr.DataArray``: ``main_clock_dataarray``. To refer to the main + clock as a dimension label, the placeholder ``xs.MAIN_CLOCK`` can be used. + This will be set to the main clock when storing the dataset. - Changed default ``fill_value`` in the zarr stores to maximum dtype value for integer dtypes and ``np.nan`` for floating-point variables. - - v0.5.0 (26 January 2021) ------------------------ diff --git a/xsimlab/__init__.py b/xsimlab/__init__.py index f0399169..4a842ff3 100644 --- a/xsimlab/__init__.py +++ b/xsimlab/__init__.py @@ -24,6 +24,7 @@ group, group_dict, ) +from .utils import MAIN_CLOCK from .xr_accessor import SimlabAccessor, create_setup from . 
import monitoring diff --git a/xsimlab/drivers.py b/xsimlab/drivers.py index 5c8e0c00..23e12470 100644 --- a/xsimlab/drivers.py +++ b/xsimlab/drivers.py @@ -7,7 +7,7 @@ from .hook import flatten_hooks, group_hooks, RuntimeHook from .process import RuntimeSignal from .stores import ZarrSimulationStore -from .utils import get_batch_size +from .utils import get_batch_size, MAIN_CLOCK class ValidateOption(Enum): @@ -28,6 +28,8 @@ class RuntimeContext(Mapping[str, Any]): "batch", "sim_start", "sim_end", + "main_clock_values", + "main_clock_dataarray", "step", "nsteps", "step_start", @@ -161,6 +163,8 @@ def _generate_runtime_datasets(dataset): init_data_vars = { "_sim_start": mclock_coord[0], "_nsteps": dataset.xsimlab.nsteps, + # since we pass a dataset, we need to set the coords + "_main_clock_values": dataset.coords[mclock_dim].data, "_sim_end": mclock_coord[-1], } @@ -327,6 +331,8 @@ def _run( sim_start=ds_init["_sim_start"].values, nsteps=ds_init["_nsteps"].values, sim_end=ds_init["_sim_end"].values, + main_clock_values=ds_init["_main_clock_values"].values, + main_clock_dataarray=dataset.xsimlab.main_clock_coord, ) in_vars = _get_input_vars(ds_init, model) diff --git a/xsimlab/stores.py b/xsimlab/stores.py index f23479bd..752cff01 100644 --- a/xsimlab/stores.py +++ b/xsimlab/stores.py @@ -7,7 +7,7 @@ import zarr from . import Model -from .utils import get_batch_size, normalize_encoding +from .utils import get_batch_size, normalize_encoding, MAIN_CLOCK from .variable import VarType @@ -252,6 +252,14 @@ def _create_zarr_dataset( f"its accepted dimension(s): {var_info['metadata']['dims']}" ) + # set MAIN_CLOCK placeholder to main_clock dimension + if self.mclock_dim in dim_labels and MAIN_CLOCK in dim_labels: + raise ValueError( + f"Main clock: '{self.mclock_dim}' has a duplicate in {dim_labels}." 
+ "Please change the name of 'main_clock' in `create_setup`" + ) + dim_labels = [self.mclock_dim if d is MAIN_CLOCK else d for d in dim_labels] + if clock is not None: dim_labels.insert(0, clock) if add_batch_dim: @@ -331,7 +339,6 @@ def write_output_vars(self, batch: int, step: int, model: Optional[Model] = None else: idx_dims = [clock_inc] + [slice(0, n) for n in np.shape(value)] - if batch != -1: idx_dims.insert(0, batch) diff --git a/xsimlab/tests/test_model.py b/xsimlab/tests/test_model.py index 90d99bf4..9228db5d 100644 --- a/xsimlab/tests/test_model.py +++ b/xsimlab/tests/test_model.py @@ -452,3 +452,68 @@ def initialize(self): model.execute("initialize", {}) assert model.state[("baz", "actual")] == Frozen({("foo", "a"): 1, ("bar", "b"): 2}) + + +def test_main_clock_access(): + @xs.process + class Foo: + a = xs.variable(intent="out", dims=xs.MAIN_CLOCK) + b = xs.variable(intent="out", dims=xs.MAIN_CLOCK) + + @xs.runtime(args=["main_clock_values", "main_clock_dataarray"]) + def initialize(self, clock_values, clock_array): + self.a = clock_values * 2 + np.testing.assert_equal(self.a, [0, 2, 4, 6]) + self.b = clock_array * 2 + assert clock_array.dims[0] == "clock" + assert all(clock_array[clock_array.dims[0]].data == [0, 1, 2, 3]) + + @xs.runtime(args=["step_delta", "step"]) + def run_step(self, dt, n): + assert self.a[n] == 2 * n + self.a[n] += 1 + + model = xs.Model({"foo": Foo}) + ds_in = xs.create_setup( + model=model, + clocks={"clock": range(4)}, + input_vars={}, + output_vars={"foo__a": None}, + ) + ds_out = ds_in.xsimlab.run(model=model) + assert all(ds_out.foo__a.data == [1, 3, 5, 6]) + + # test for error when another dim has the same name as xs.MAIN_CLOCK + @xs.process + class DoubleMainClockDim: + a = xs.variable(intent="out", dims=("clock", xs.MAIN_CLOCK)) + + def initialize(self): + self.a = [[1, 2, 3], [3, 4, 5]] + + def run_step(self): + self.a += self.a + + model = xs.Model({"foo": DoubleMainClockDim}) + with pytest.raises(ValueError, 
match=r"Main clock:*"): + xs.create_setup( + model=model, + clocks={"clock": [0, 1, 2, 3]}, + input_vars={}, + output_vars={"foo__a": None}, + ).xsimlab.run(model) + + # test for error when trying to put xs.MAIN_CLOCK as a dim in an input var + with pytest.raises( + ValueError, match="Do not pass xs.MAIN_CLOCK into input vars dimensions" + ): + a = xs.variable(intent="in", dims=xs.MAIN_CLOCK) + + with pytest.raises( + ValueError, match="Do not pass xs.MAIN_CLOCK into input vars dimensions" + ): + b = xs.variable(intent="in", dims=(xs.MAIN_CLOCK,)) + with pytest.raises( + ValueError, match="Do not pass xs.MAIN_CLOCK into input vars dimensions" + ): + c = xs.variable(intent="in", dims=["a", ("a", xs.MAIN_CLOCK)]) diff --git a/xsimlab/tests/test_xr_accessor.py b/xsimlab/tests/test_xr_accessor.py index 42924be8..17cec751 100644 --- a/xsimlab/tests/test_xr_accessor.py +++ b/xsimlab/tests/test_xr_accessor.py @@ -158,9 +158,9 @@ def test_master_clock_coords_warning(self): ) with pytest.warns( FutureWarning, - match="master_clock is to be deprecated in favour of main_clock", + match="master_clock_coord is to be deprecated in favour of main_clock", ): - xr.testing.assert_equal(ds.xsimlab.master_clock_coord, ds.mclock) + ds.xsimlab.master_clock_coord def test_clock_sizes(self): ds = xr.Dataset( diff --git a/xsimlab/utils.py b/xsimlab/utils.py index 64e9416b..bbf2ddd6 100644 --- a/xsimlab/utils.py +++ b/xsimlab/utils.py @@ -15,6 +15,34 @@ V = TypeVar("V") +class _MainClockDim: + """Singleton class to be used as a placeholder of the main clock + dimension. + + It will be replaced by the actual dimension label set during simulation setup + (i.e., ``main_clock`` argument). 
+ + """ + + _singleton = None + + def __new__(cls): + if _MainClockDim._singleton is None: + # if there is no instance of it yet, create a class instance + _MainClockDim._singleton = super(_MainClockDim, cls).__new__(cls) + return _MainClockDim._singleton + + def __repr__(self): + return "MAIN_CLOCK (undefined)" + + +MAIN_CLOCK = _MainClockDim() +""" +Sentinel to indicate simulation's main clock dimension, to be +replaced by the actual dimension label set in input/output datasets. +""" + + def variables_dict(process_cls): """Get all xsimlab variables declared in a process. diff --git a/xsimlab/variable.py b/xsimlab/variable.py index 95823cf6..d6b642a1 100644 --- a/xsimlab/variable.py +++ b/xsimlab/variable.py @@ -5,7 +5,7 @@ import attr from attr._make import _CountingAttr -from .utils import normalize_encoding +from .utils import normalize_encoding, MAIN_CLOCK class VarType(Enum): @@ -57,12 +57,18 @@ def _as_dim_tuple(dims): ambiguous and thus not allowed. """ - if not len(dims): + # MAIN_CLOCK is sentinel and does not have length (or zero), so check explicitly + if dims is MAIN_CLOCK: + dims = [(dims,)] + elif not len(dims): dims = [()] elif isinstance(dims, str): dims = [(dims,)] elif isinstance(dims, list): - dims = [tuple([d]) if isinstance(d, str) else tuple(d) for d in dims] + dims = [ + tuple([d]) if (isinstance(d, str) or d is MAIN_CLOCK) else tuple(d) + for d in dims + ] else: dims = [dims] @@ -221,6 +227,9 @@ def variable( else: _init = True _repr = True + # also check if MAIN_CLOCK is there + if any([MAIN_CLOCK in d for d in metadata["dims"]]): + raise ValueError("Do not pass xs.MAIN_CLOCK into input vars dimensions") return attr.attrib( metadata=metadata, diff --git a/xsimlab/xr_accessor.py b/xsimlab/xr_accessor.py index d9c20a5e..a12343da 100644 --- a/xsimlab/xr_accessor.py +++ b/xsimlab/xr_accessor.py @@ -209,7 +209,7 @@ def master_clock_coord(self): Returns None if no main clock is defined in the dataset. 
""" warnings.warn( - "master_clock is to be deprecated in favour of main_clock", + "master_clock_coord is to be deprecated in favour of main_clock", FutureWarning, ) return self.main_clock_coord @@ -224,7 +224,7 @@ def main_clock_coord(self): @property def nsteps(self): - """Number of simulation steps, computed from the master + """Number of simulation steps, computed from the main clock coordinate. Returns 0 if no main clock is defined in the dataset. @@ -319,6 +319,7 @@ def _uniformize_clock_coords(self, dim=None, units=None, calendar=None): ) def _set_input_vars(self, model, input_vars): + invalid_inputs = set(input_vars) - set(model.input_vars) if invalid_inputs: