Skip to content

Commit

Permalink
updates to methods
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Aug 30, 2023
1 parent 82d66fe commit 4d20ee3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 34 deletions.
18 changes: 9 additions & 9 deletions earthkit/climate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@
except ImportError:
pass
else:
KWARG_TYPES = {
# "dataarray": xr.DataArray,
# "dataset": xr.Dataset,
}
aggregate = transform_module_inputs(aggregate)

aggregate = transform_module_inputs(aggregate, kwarg_types=KWARG_TYPES)
climatology = transform_module_inputs(climatology)

climatology = transform_module_inputs(climatology, kwarg_types=KWARG_TYPES)
shapes = transform_module_inputs(shapes)

shapes = transform_module_inputs(shapes, kwarg_types=KWARG_TYPES)

__all__ = ["__version__", "aggregate", "climatology", "shapes"]
__all__ = [
"__version__",
"aggregate",
"climatology",
"shapes"
]
71 changes: 56 additions & 15 deletions earthkit/climate/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}


def daily_mean(dataarray, **kwargs):
def daily_mean(dataarray: T.Union[xr.Dataset, xr.DataArray], **kwargs):
"""
Calculate the daily mean.
Expand All @@ -46,7 +46,7 @@ def daily_mean(dataarray, **kwargs):
return resample(dataarray, frequency="D", dim="time", how="mean", **kwargs)


def daily_max(dataarray, **kwargs):
def daily_max(dataarray: T.Union[xr.Dataset, xr.DataArray], **kwargs):
"""
Calculate the daily max.
Expand All @@ -64,7 +64,7 @@ def daily_max(dataarray, **kwargs):
return resample(dataarray, frequency="D", dim="time", how="max", **kwargs)


def daily_min(dataarray, **kwargs):
def daily_min(dataarray: T.Union[xr.Dataset, xr.DataArray], **kwargs):
"""
Calculate the daily min.
Expand All @@ -82,7 +82,7 @@ def daily_min(dataarray, **kwargs):
return resample(dataarray, frequency="D", dim="time", how="min", **kwargs)


def monthly_mean(dataarray, **kwargs):
def monthly_mean(dataarray: T.Union[xr.Dataset, xr.DataArray], **kwargs):
"""
Calculate the monthly mean.
Expand All @@ -101,7 +101,7 @@ def monthly_mean(dataarray, **kwargs):


def resample(
dataarray: xr.DataArray,
dataarray: T.Union[xr.Dataset, xr.DataArray],
frequency: str or int or float,
dim: str = "time",
how: str = "mean",
Expand Down Expand Up @@ -140,7 +140,7 @@ def resample(


def _groupby_time(
dataarray: xr.DataArray,
dataarray: T.Union[xr.Dataset, xr.DataArray],
frequency: str = None,
bin_widths: int = None,
squeeze: bool = True,
Expand Down Expand Up @@ -175,7 +175,7 @@ def _groupby_time(


def _groupby_bins(
dataarray: xr.DataArray,
dataarray: T.Union[xr.Dataset, xr.DataArray],
frequency: str,
bin_widths: int,
squeeze: bool,
Expand Down Expand Up @@ -280,7 +280,7 @@ def _reduce_dataarray(

def reduce(
dataarray: T.Union[xr.DataArray, xr.Dataset],
**kwargs,
*args, **kwargs,
):
"""
Reduce an xarray.dataarray or xarray.dataset using a specified `how` method.
Expand Down Expand Up @@ -314,14 +314,57 @@ def reduce(
"""
if isinstance(dataarray, (xr.Dataset)):
return xr.Dataset(
[_reduce_dataarray(dataarray[var], **kwargs) for var in dataarray.data_vars]
)
out_ds = xr.Dataset().assign_attrs(dataarray.attrs)
for var in dataarray.data_vars:
out_da = _reduce_dataarray(dataarray[var], *args, **kwargs)
out_ds[out_da.name] = out_da
return out_ds
else:
return _reduce_dataarray(dataarray, **kwargs)
return _reduce_dataarray(dataarray, *args, **kwargs)


def rolling_reduce(
dataarray: T.Union[xr.Dataset, xr.DataArray], *args, **kwargs
) -> xr.DataArray:
"""Return reduced data using a moving window over which to apply the reduction.
Parameters
----------
dataarray : xr.DataArray or xr.Dataset
Data over which the moving window is applied according to the reduction method.
windows :
windows for the rolling groups, for example `time=10` to perform a reduction
in the time dimension with a bin size of 10. the rolling groups can be defined
over any number of dimensions. **see documentation for xarray.dataarray.rolling**.
min_periods : integer
The minimum number of observations in the window required to have a value
(otherwise result is NaN). Default is to set **min_periods** equal to the size of the window.
**see documentation for xarray.dataarray.rolling**
center : bool
Set the labels at the centre of the window, **see documentation for xarray.dataarray.rolling**.
how_reduce : str,
Function to be applied for reduction. Default is 'mean'.
how_dropna : str
Determine if dimension is removed from the output when we have at least one NaN or
all NaN. **how_dropna** can be 'None', 'any' or 'all'. Default is 'any'.
**kwargs :
Any kwargs that are compatible with the select `how_reduce` method.
Returns
-------
xr.DataArray or xr.Dataset (as provided)
"""
if isinstance(dataarray, (xr.Dataset)):
out_ds = xr.Dataset().assign_attrs(dataarray.attrs)
for var in dataarray.data_vars:
out_da = _rolling_reduce_dataarray(dataarray[var], *args, **kwargs)
out_ds[out_da.name] = out_da
return out_ds
else:
return _rolling_reduce_dataarray(dataarray, *args, **kwargs)


def _rolling_reduce_dataarray(
dataarray: xr.DataArray, how_reduce="mean", how_dropna="any", **kwargs
) -> xr.DataArray:
"""Return reduced data using a moving window over which to apply the reduction.
Expand Down Expand Up @@ -363,12 +406,10 @@ def rolling_reduce(

# Any kwargs left after above reductions are kwargs for reduction method
reduce_kwargs = kwargs
# print("rolling kwargs: ", rolling_kwargs)
# Create rolling groups:
data_rolling = dataarray.rolling(**rolling_kwargs)
# print("reduce kwargs: ", reduce_kwargs)

data_windowed = reduce(data_rolling, how=how_reduce, **reduce_kwargs)
data_windowed = _reduce_dataarray(data_rolling, how=how_reduce, **reduce_kwargs)

data_windowed = _dropna(data_windowed, window_dims, how_dropna)

Expand Down
19 changes: 9 additions & 10 deletions earthkit/climate/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def masks(
def reduce(
dataarray: T.Union[xr.Dataset, xr.DataArray],
geodataframe: gpd.GeoDataFrame,
**kwargs,
*args, **kwargs,
):
"""
Apply a shape object to an xarray.DataArray object using the specified 'how' method.
Expand Down Expand Up @@ -348,20 +348,19 @@ def reduce(
"""
if isinstance(dataarray, xr.Dataset):
if kwargs.get("return_as", "pandas") in ["xarray"]:
return xr.Dataset(
[
_reduce_dataarray(dataarray[var], geodataframe, **kwargs)
for var in dataarray.data_vars
]
)
if kwargs.get("return_as", "xarray") in ["xarray"]:
out_ds = xr.Dataset().assign_attrs(dataarray.attrs)
for var in dataarray.data_vars:
out_da = _reduce_dataarray(dataarray[var], *args, **kwargs)
out_ds[out_da.name] = out_da
return out_ds
else:
out = geodataframe
for var in dataarray.data_vars:
out = _reduce_dataarray(dataarray[var], geodataframe, **kwargs)
out = _reduce_dataarray(dataarray[var], geodataframe, *args, **kwargs)
return out
else:
return _reduce_dataarray(dataarray, geodataframe, **kwargs)
return _reduce_dataarray(dataarray, geodataframe, *args, **kwargs)


def _reduce_dataarray(
Expand Down

0 comments on commit 4d20ee3

Please sign in to comment.