Skip to content

Commit

Permalink
qa
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Aug 30, 2023
1 parent 558c2c6 commit 82d66fe
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
29 changes: 15 additions & 14 deletions earthkit/climate/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from copy import deepcopy

import geopandas as gpd
import pandas as pd
import numpy as np
import pandas as pd
import xarray as xr

from earthkit.climate.tools import (
Expand Down Expand Up @@ -242,16 +242,17 @@ def mask(
A masked data array/dataset with same dimensions as the input dataarray/dataset. Any point that
does not lie in any of the features of geodataframe is masked.
"""
spatial_info = get_spatial_info(
dataarray, lat_key=lat_key, lon_key=lon_key
)
spatial_info = get_spatial_info(dataarray, lat_key=lat_key, lon_key=lon_key)
# Get spatial info required by mask functions:
mask_kwargs.update({key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]})
mask_kwargs.update(
{key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]}
)
mask = shapes_to_mask(geodataframe, dataarray, **mask_kwargs)
out = dataarray.where(mask)
out = dataarray.where(mask)
out.attrs.update(geodataframe.attrs)
return out


def masks(
dataarray: T.Union[xr.Dataset, xr.DataArray],
geodataframe: gpd.geodataframe.GeoDataFrame,
Expand Down Expand Up @@ -287,11 +288,11 @@ def masks(
Each slice of layer corresponds to a feature in layer.
"""
masked_arrays = []
spatial_info = get_spatial_info(
dataarray, lat_key=lat_key, lon_key=lon_key
)
spatial_info = get_spatial_info(dataarray, lat_key=lat_key, lon_key=lon_key)
# Get spatial info required by mask functions:
mask_kwargs.update({key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]})
mask_kwargs.update(
{key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]}
)

for mask in _shape_mask_iterator(geodataframe, dataarray, **mask_kwargs):
masked_arrays.append(dataarray.where(mask))
Expand Down Expand Up @@ -443,11 +444,11 @@ def _reduce_dataarray(
if isinstance(extra_reduce_dims, str):
extra_reduce_dims = [extra_reduce_dims]

spatial_info = get_spatial_info(
dataarray, lat_key=lat_key, lon_key=lon_key
)
spatial_info = get_spatial_info(dataarray, lat_key=lat_key, lon_key=lon_key)
# Get spatial info required by mask functions:
mask_kwargs.update({key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]})
mask_kwargs.update(
{key: spatial_info[key] for key in ["lat_key", "lon_key", "regular"]}
)
spatial_dims = spatial_info.get("spatial_dims")

reduce_dims = spatial_dims + extra_reduce_dims
Expand Down
9 changes: 5 additions & 4 deletions earthkit/climate/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def get_dim_key(


def get_spatial_info(dataarray, lat_key=None, lon_key=None):

# Figure out the keys for the latitude and longitude variables
if lat_key is None:
lat_key = get_dim_key(dataarray, "y")
Expand All @@ -223,15 +222,17 @@ def get_spatial_info(dataarray, lat_key=None, lon_key=None):
# will be 'lat' and 'lon'. For irregular data it could be any dimensions
lat_dims = dataarray.coords[lat_key].dims
lon_dims = dataarray.coords[lon_key].dims
spatial_dims = [dim for dim in lat_dims] + [dim for dim in lon_dims if dim not in lat_dims]
spatial_dims = [dim for dim in lat_dims] + [
dim for dim in lon_dims if dim not in lat_dims
]

# Assert that latitude and longitude have the same dimensions
# (irregular data, e.g. x&y or obs)
# or the dimensions are themselves (regular data, 'lat'&'lon')
assert (lat_dims == lon_dims) or (
(lat_dims == (lat_key,)) and (lon_dims) == (lon_key,)
)
if (lat_dims == lon_dims):
if lat_dims == lon_dims:
regular = False
elif (lat_dims == (lat_key,)) and (lon_dims) == (lon_key,):
regular = True
Expand All @@ -245,6 +246,6 @@ def get_spatial_info(dataarray, lat_key=None, lon_key=None):
"lat_key": lat_key,
"lon_key": lon_key,
"regular": regular,
"spatial_dims": spatial_dims
"spatial_dims": spatial_dims,
}
return spatial_info

0 comments on commit 82d66fe

Please sign in to comment.