Skip to content

Commit

Permalink
Add subsetting option to to_numpy method
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz committed Jun 18, 2024
1 parent 721fe99 commit 9c76c22
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 22 deletions.
83 changes: 61 additions & 22 deletions src/earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _metadata(self):
self.__metadata = self._make_metadata()
return self.__metadata

def to_numpy(self, flatten=False, dtype=None):
def to_numpy(self, flatten=False, dtype=None, index=None):
r"""Return the values stored in the field as an ndarray.
Parameters
Expand All @@ -137,6 +137,9 @@ def to_numpy(self, flatten=False, dtype=None):
dtype: str, numpy.dtype or None
Typecode or data-type of the array. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is ``float64``.
index: ndarray indexing object, optional
The index of the values to be extracted. When it
is None all the values are extracted.
Returns
-------
Expand All @@ -148,10 +151,12 @@ def to_numpy(self, flatten=False, dtype=None):
v = numpy_backend().to_array(v, self.raw_values_backend)
shape = self._required_shape(flatten)
if shape != v.shape:
return v.reshape(shape)
v = v.reshape(shape)
if index is not None:
v = v[index]
return v

def to_array(self, flatten=False, dtype=None, array_backend=None):
def to_array(self, flatten=False, dtype=None, array_backend=None, index=None):
r"""Return the values stored in the field in the
format of :attr:`array_backend`.
Expand All @@ -163,6 +168,9 @@ def to_array(self, flatten=False, dtype=None, array_backend=None):
dtype: str, array.dtype or None
Typecode or data-type of the array. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is ``float64``.
index: array indexing object, optional
The index of the values to be extracted. When it
is None all the values are extracted.
Returns
-------
Expand All @@ -177,17 +185,21 @@ def to_array(self, flatten=False, dtype=None, array_backend=None):
)
shape = self._required_shape(flatten)
if shape != v.shape:
return self._array_backend.array_ns.reshape(v, shape)
v = self._array_backend.array_ns.reshape(v, shape)
if index is not None:
v = v[index]
return v

def _required_shape(self, flatten):
return self.shape if not flatten else (math.prod(self.shape),)
def _required_shape(self, flatten, shape=None):
if shape is None:
shape = self.shape
return shape if not flatten else (math.prod(shape),)

def _array_matches(self, array, flatten=False, dtype=None):
shape = self._required_shape(flatten)
return shape == array.shape and (dtype is None or dtype == array.dtype)

def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=None):
r"""Return the values and/or the geographical coordinates for each grid point.
Parameters
Expand All @@ -201,6 +213,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
dtype: str, array.dtype or None
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is ``float64``.
index: array indexing object, optional
The index of the values and/or the latitudes/longitudes to be extracted. When it
is None all the values and/or coordinates are extracted.
Returns
-------
Expand Down Expand Up @@ -252,18 +267,22 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
if k not in _keys:
raise ValueError(f"data: invalid argument: {k}")

r = [self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1]) for k in keys]
shape = self._required_shape(flatten)
if shape != r[0].shape:
# r = [x.reshape(shape) for x in r]
r = [self._array_backend.array_ns.reshape(x, shape) for x in r]
r = []
for k in keys:
v = self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1])
shape = self._required_shape(flatten)
if shape != v.shape:
v = self._array_backend.array_ns.reshape(v, shape)
if index is not None:
v = v[index]
r.append(v)

if len(r) == 1:
return r[0]
else:
return self._array_backend.array_ns.stack(r)

def to_points(self, flatten=False, dtype=None):
def to_points(self, flatten=False, dtype=None, index=None):
r"""Return the geographical coordinates in the data's original
Coordinate Reference System (CRS).
Expand All @@ -276,6 +295,9 @@ def to_points(self, flatten=False, dtype=None):
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is
``float64``.
index: array indexing object, optional
The index of the coordinates to be extracted. When it is None
all the values are extracted.
Returns
-------
Expand Down Expand Up @@ -303,14 +325,17 @@ def to_points(self, flatten=False, dtype=None):
if shape != x.shape:
x = self._array_backend.array_ns.reshape(x, shape)
y = self._array_backend.array_ns.reshape(y, shape)
if index is not None:
x = x[index]
y = y[index]
return dict(x=x, y=y)
elif self.projection().CARTOPY_CRS == "PlateCarree":
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype)
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index)
return dict(x=lon, y=lat)
else:
raise ValueError("to_points(): geographical coordinates in original CRS are not available")

def to_latlon(self, flatten=False, dtype=None):
def to_latlon(self, flatten=False, dtype=None, index=None):
r"""Return the latitudes/longitudes of all the gridpoints in the field.
Parameters
Expand All @@ -322,6 +347,9 @@ def to_latlon(self, flatten=False, dtype=None):
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is
``float64``.
index: array indexing object, optional
The index of the latitudes/longitudes to be extracted. When it is None
all the values are extracted.
Returns
-------
Expand All @@ -335,7 +363,7 @@ def to_latlon(self, flatten=False, dtype=None):
to_points
"""
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype)
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index)
return dict(lat=lat, lon=lon)

def grid_points(self):
Expand Down Expand Up @@ -869,7 +897,7 @@ def to_array(self, **kwargs):

@property
def values(self):
r"""array-likr: Get all the fields' values as a 2D array. It is formed as the array of
r"""array-like: Get all the fields' values as a 2D array. It is formed as the array of
:obj:`GribField.values <data.readers.grib.codes.GribField.values>` per field.
See Also
Expand All @@ -893,7 +921,13 @@ def values(self):
x = [f.values for f in self]
return self._array_backend.array_ns.stack(x)

def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
def data(
self,
keys=("lat", "lon", "value"),
flatten=False,
dtype=None,
index=None,
):
r"""Return the values and/or the geographical coordinates.
Only works when all the fields have the same grid geometry.
Expand All @@ -910,6 +944,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is
``float64``.
index: array indexing object, optional
The index of the values to be extracted from each field. When it is None all the
values are extracted.
Returns
-------
Expand Down Expand Up @@ -962,7 +999,7 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
keys = [keys]

if "lat" in keys or "lon" in keys:
latlon = self[0].to_latlon(flatten=flatten, dtype=dtype)
latlon = self[0].to_latlon(flatten=flatten, dtype=dtype, index=index)

r = []
for k in keys:
Expand All @@ -971,10 +1008,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
elif k == "lon":
r.append(latlon["lon"])
elif k == "value":
r.extend([f.to_array(flatten=flatten, dtype=dtype) for f in self])
r.extend([f.to_array(flatten=flatten, dtype=dtype, index=index) for f in self])
else:
raise ValueError(f"data: invalid argument: {k}")

return self._array_backend.array_ns.stack(r)

elif len(self) == 0:
Expand Down Expand Up @@ -1226,11 +1262,14 @@ def to_points(self, **kwargs):
else:
raise ValueError("Fields do not have the same grid geometry")

def to_latlon(self, **kwargs):
def to_latlon(self, index=None, **kwargs):
r"""Return the latitudes/longitudes shared by all the fields.
Parameters
----------
index: array indexing object, optional
The index of the latitudes/longitudes to be extracted. When it is None
all the values are extracted.
**kwargs: dict, optional
Keyword arguments passed to
:meth:`Field.to_latlon() <data.core.fieldlist.Field.to_latlon>`
Expand Down
Loading

0 comments on commit 9c76c22

Please sign in to comment.