Skip to content

Commit 90f2552

Browse files
authored
Allow non-numeric coordinates (e.g. time) as spatial axes (#65)
* fix: slicing should use indices, not coordinate values * fix: use indices for mesh coordinates if coord dtype is not numeric * feat: Add scales arg to specify scale of an index array made for non-numeric coords * style: reformat with black * fix: protect scales value when slicing is None * fix: Modify slicing behavior * test: Add two test cases for `test_source`
1 parent dceabc7 commit 90f2552

File tree

5 files changed

+78
-36
lines changed

5 files changed

+78
-36
lines changed

pvxarray/accessor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Optional
1+
from typing import Dict, Optional
22

3+
import numpy as np
34
import pyvista as pv
45
import xarray as xr
56

@@ -38,9 +39,13 @@ def loc(self) -> _LocIndexer:
3839
"""Attribute for location based indexing like pandas."""
3940
return _LocIndexer(self)
4041

41-
def _get_array(self, key):
42+
def _get_array(self, key, scale=1):
4243
try:
43-
return self._obj[key].values
44+
values = self._obj[key].values
45+
if "float" not in str(values.dtype) and "int" not in str(values.dtype):
46+
# non-numeric coordinate, assign array of scaled indices
47+
values = np.array(range(len(values))) * scale
48+
return values
4449
except KeyError:
4550
raise KeyError(
4651
f"Key {key} not present in DataArray. Choices are: {list(self._obj.coords.keys())}"
@@ -58,6 +63,7 @@ def mesh(
5863
order: Optional[str] = None,
5964
component: Optional[str] = None,
6065
mesh_type: Optional[str] = None,
66+
scales: Optional[Dict] = None,
6167
) -> pv.DataSet:
6268
ndim = 0
6369
if x is not None:
@@ -80,7 +86,7 @@ def mesh(
8086
meth = methods[mesh_type]
8187
except KeyError:
8288
raise KeyError
83-
return meth(self, x=x, y=y, z=z, order=order, component=component)
89+
return meth(self, x=x, y=y, z=z, order=order, component=component, scales=scales)
8490

8591
def plot(
8692
self,

pvxarray/rectilinear.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Dict, Optional
22
import warnings
33

44
import pyvista as pv
@@ -13,6 +13,7 @@ def mesh(
1313
z: Optional[str] = None,
1414
order: Optional[str] = "C",
1515
component: Optional[str] = None,
16+
scales: Optional[Dict] = None,
1617
):
1718
if order is None:
1819
order = "C"
@@ -22,11 +23,11 @@ def mesh(
2223
raise ValueError("You must specify at least one dimension as X, Y, or Z.")
2324
# Construct the mesh
2425
if x is not None:
25-
self._mesh.x = self._get_array(x)
26+
self._mesh.x = self._get_array(x, scale=(scales and scales.get(x)) or 1)
2627
if y is not None:
27-
self._mesh.y = self._get_array(y)
28+
self._mesh.y = self._get_array(y, scale=(scales and scales.get(y)) or 1)
2829
if z is not None:
29-
self._mesh.z = self._get_array(z)
30+
self._mesh.z = self._get_array(z, scale=(scales and scales.get(z)) or 1)
3031
# Handle data values
3132
values = self.data
3233
values_dim = values.ndim

pvxarray/structured.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Dict, Optional
22
import warnings
33

44
import numpy as np
@@ -41,6 +41,7 @@ def _points(
4141
y: Optional[str] = None,
4242
z: Optional[str] = None,
4343
order: Optional[str] = "F",
44+
scales: Optional[Dict] = None,
4445
):
4546
"""Generate structured points as new array."""
4647
if order is None:
@@ -52,11 +53,11 @@ def _points(
5253
raise ValueError("One dimensional structured grids should be rectilinear grids.")
5354
raise ValueError("You must specify at least two dimensions as X, Y, or Z.")
5455
if x is not None:
55-
x = self._get_array(x)
56+
x = self._get_array(x, scale=(scales and scales.get(x)) or 1)
5657
if y is not None:
57-
y = self._get_array(y)
58+
y = self._get_array(y, scale=(scales and scales.get(y)) or 1)
5859
if z is not None:
59-
z = self._get_array(z)
60+
z = self._get_array(z, scale=(scales and scales.get(z)) or 1)
6061
arrs = _coerce_shapes(x, y, z)
6162
x, y, z = arrs
6263
arr = [a for a in arrs if a is not None][0]
@@ -78,6 +79,7 @@ def mesh(
7879
z: Optional[str] = None,
7980
order: str = "F",
8081
component: Optional[str] = None, # TODO
82+
scales: Optional[Dict] = None,
8183
):
8284
if order is None:
8385
order = "F"
@@ -88,7 +90,7 @@ def mesh(
8890
"StructuredGrid accessor duplicates data - VTK/PyVista data not shared with xarray."
8991
)
9092
)
91-
points, shape = _points(self, x=x, y=y, z=z, order=order)
93+
points, shape = _points(self, x=x, y=y, z=z, order=order, scales=scales)
9294
self._mesh.points = points
9395
self._mesh.dimensions = shape
9496
data = self.data

pvxarray/vtk_source.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -228,30 +228,21 @@ def _compute_sliced_data_array(self):
228228
self._sliced_data_array = None
229229
return None
230230

231+
indexing = {}
232+
if self._slicing is not None:
233+
indexing = {
234+
k: slice(*v) for k, v in self._slicing.items() if k in [self.x, self.y, self.z]
235+
}
236+
231237
if self._time is not None:
232-
da = self.data_array[{self._time: self.time_index}]
233-
else:
234-
da = self.data_array
235-
236-
if self._z and self._z_index is not None:
237-
da = da[{self._z: self.z_index}]
238-
239-
if self._slicing:
240-
indexing = {}
241-
for axis in [
242-
self.x,
243-
self.y,
244-
self.z,
245-
]:
246-
if axis in self._slicing:
247-
s = self._slicing[axis]
248-
c = da.coords[axis]
249-
sliced_array = np.where(np.logical_and(c >= s[0], c <= s[1]))[0]
250-
sliced_array = sliced_array[:: int(s[2])]
251-
indexing[axis] = sliced_array
252-
da = da.isel(**indexing)
253-
254-
elif self._resolution:
238+
indexing.update(**{self._time: self.time_index})
239+
240+
if self.z and self.z_index is not None:
241+
indexing.update(**{self.z: self.z_index})
242+
243+
da = self.data_array.isel(indexing)
244+
245+
if self._slicing is None and self._resolution is not None:
255246
rx, ry, rz = self.resolution_to_sampling_rate(da)
256247
if da.ndim <= 1:
257248
da = da[::rx]
@@ -271,6 +262,7 @@ def _compute_mesh(self):
271262
order=self._order,
272263
component=self._component,
273264
mesh_type=self._mesh_type,
265+
scales={k: v[2] for k, v in self._slicing.items()} if self._slicing else {},
274266
)
275267
return self._mesh
276268

tests/test_source.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,44 @@ def test_vtk_source():
2828
source.resolution = 0.5
2929
mesh = source.apply()
3030
assert mesh.n_points < 1325
31+
32+
33+
def test_vtk_source_time_as_spatial():
34+
ds = xr.tutorial.load_dataset("air_temperature")
35+
36+
da = ds.air
37+
source = PyVistaXarraySource(da, x="lon", y="lat", z="time")
38+
39+
mesh = source.apply()
40+
assert mesh
41+
assert mesh.n_points == 3869000
42+
assert "air" in mesh.point_data
43+
44+
assert np.array_equal(mesh["air"], da.values.ravel())
45+
assert np.array_equal(mesh.x, da.lon)
46+
assert np.array_equal(mesh.y, da.lat)
47+
# Z values are indexes instead of datetime objects
48+
assert np.array_equal(mesh.z, list(range(da.time.size)))
49+
50+
51+
def test_vtk_source_slicing():
52+
ds = xr.tutorial.load_dataset("eraint_uvz")
53+
54+
da = ds.z
55+
source = PyVistaXarraySource(
56+
da,
57+
x="longitude",
58+
y="latitude",
59+
z="level",
60+
time="month",
61+
)
62+
source.time_index = 1
63+
source.slicing = {
64+
"latitude": [0, 241, 2],
65+
"longitude": [0, 480, 4],
66+
"level": [0, 3, 1],
67+
"month": [0, 2, 1], # should be ignored in favor of t_index
68+
}
69+
70+
sliced = source.sliced_data_array
71+
assert sliced.shape == (3, 121, 120)

0 commit comments

Comments
 (0)