diff --git a/ewatercycle/models/abstract.py b/ewatercycle/models/abstract.py index 3aafb35b..0d100025 100644 --- a/ewatercycle/models/abstract.py +++ b/ewatercycle/models/abstract.py @@ -157,6 +157,40 @@ def get_value_as_xarray(self, name: str) -> xr.DataArray: """ + def set_value_as_xarray(self, name: str, value: xr.DataArray) -> None: + """Specify a new value for a model variable by passing in an xarray.DataArray. + + The DataArray should have the same grid as the model. + + Args: + name: Name of variable + value: xarray.DataArray with values for the specified variable. + + """ + model_grid = self.bmi.get_var_grid(name) + model_shape = self.bmi.get_grid_shape(model_grid) + if model_shape == value.shape: + self.bmi.set_value(name, value.data.flatten()) + elif model_shape == value.T.shape: + self.bmi.set_value(name, value.T.data.flatten()) + else: + raise ValueError( + f"Shape mismatch. Model has shape {model_shape}, but" + f"input has shape {value.shape}." + ) + # TODO: what if data is NaN? Model-specific back-conversion? Don't convert to NaN in getters? + # TODO: wflow not settable (rank mismatch, should be 2???!) + # TODO: pcrglob discharge not settable (but works for temperature) + # TODO: add tests like so: + """ + name = 'temperature' + da = model.get_value_as_xarray(name) + orig = model.get_value(name) + model.set_value_as_xarray(name, da) + new = model.get_value(name) + assert all(orig == new) + """ + @property @abstractmethod def parameters(self) -> Iterable[Tuple[str, Any]]: