Skip to content

Commit

Permalink
chunky maskable stew
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Sep 6, 2023
1 parent feca83e commit dbb45a2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
38 changes: 31 additions & 7 deletions earthkit/climate/aggregate/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def masks(
mask_dim: T.Union[str, None] = None,
lat_key: T.Union[None, str] = None,
lon_key: T.Union[None, str] = None,
chunk: bool = True,
**mask_kwargs,
):
"""
Expand All @@ -281,25 +282,35 @@ def masks(
of the geodataframe
lat_key/lon_key :
key for latitude/longitude variable, default behaviour is to detect variable keys.
chunk : (optional) bool
Boolean to indicate whether to use chunking, default = `True`.
This is advised as spatial.masks can create large results. If you are working with small
arrays, or you have implemented your own chunking rules, you may wish to disable it.
Returns
-------
A masked data array with dimensions [feature_id] + [data.dims].
Each slice of layer corresponds to a feature in layer.
"""
masked_arrays = []
spatial_info = get_spatial_info(dataarray, lat_key=lat_key, lon_key=lon_key)
# Get spatial info required by mask functions:
mask_kwargs.update(
{key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]}
)

for mask in _shape_mask_iterator(geodataframe, dataarray, **mask_kwargs):
masked_arrays.append(dataarray.where(mask))

mask_dim_index = get_mask_dim_index(mask_dim, geodataframe)

out = xr.concat(masked_arrays, dim=mask_dim_index)
masked_arrays = []
for mask in _shape_mask_iterator(geodataframe, dataarray, **mask_kwargs):
this_masked_array = dataarray.where(mask)
if chunk:
this_masked_array = this_masked_array.chunk()
masked_arrays.append(this_masked_array.copy())
out = xr.concat(masked_arrays, dim=mask_dim_index.name)
if chunk:
out = out.chunk({mask_dim_index.name: 1})

out = out.assign_coords({mask_dim_index.name: mask_dim_index})

out.attrs.update(geodataframe.attrs)

Expand Down Expand Up @@ -379,6 +390,7 @@ def _reduce_dataarray(
how_label: T.Union[str, None] = None,
squeeze: bool = True,
mask_kwargs: T.Dict[str, T.Any] = {},
return_geometry_as_coord: bool = False,
**reduce_kwargs,
):
"""
Expand Down Expand Up @@ -411,6 +423,10 @@ def _reduce_dataarray(
Any kwargs to pass into the mask method
reduce_kwargs :
kwargs recognised by the how function
return_geometry_as_coord :
include the geometries as a coordinate in the returned xarray object. WARNING: geometries are not
serialisable objects, therefore this xarray will not be saveable as netCDF.
Returns
-------
Expand Down Expand Up @@ -501,16 +517,24 @@ def _reduce_dataarray(
elif return_as in ["pandas_compact"]:
# add the reduced data into a new column as a numpy array,
# store the dim information in the attributes

out_dims = {
dim: dataarray.coords.get(dim).values if dim in dataarray.coords else None
dim: dataarray.coords.get(dim).values
if dim in dataarray.coords
else None
for dim in reduced_list[0].dims
}
reduce_attrs[f"{new_short_name}"].update({"dims": out_dims})
reduced_list = [red.values for red in reduced_list]
out = out.assign(**{new_short_name: reduced_list})
out.attrs.update({"reduce_attrs": reduce_attrs})
else:
if return_geometry_as_coord:
out_xr = out_xr.assign_coords(
**{
"geometry": (mask_dim, [geom for geom in geodataframe["geometry"]]),
}
)
out = out_xr.assign_attrs({**geodataframe.attrs, **extra_out_attrs})

return out
1 change: 0 additions & 1 deletion earthkit/climate/aggregate/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
import functools
import typing as T
from datetime import timedelta
Expand Down

0 comments on commit dbb45a2

Please sign in to comment.