Skip to content

Commit

Permalink
Merge pull request #1151 from UXARRAY/philipc2/subset-quickfix
Browse files Browse the repository at this point in the history
Improve performance of Subset & Slicing
  • Loading branch information
aaronzedwick authored Feb 5, 2025
2 parents 9926057 + d519300 commit fe4cae1
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
6 changes: 6 additions & 0 deletions test/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@ def test_grid_face_isel():
for grid_path in GRID_PATHS:
grid = ux.open_grid(grid_path)

grid_contains_edge_node_conn = "edge_node_connectivity" in grid._ds

face_indices = [0, 1, 2, 3, 4]
for n_max_faces in range(1, len(face_indices)):
grid_subset = grid.isel(n_face=face_indices[:n_max_faces])
assert grid_subset.n_face == n_max_faces
if not grid_contains_edge_node_conn:
assert "edge_node_connectivity" not in grid_subset._ds

face_indices = [0, 1, 2, grid.n_face]
with pytest.raises(IndexError):
grid_subset = grid.isel(n_face=face_indices)
if not grid_contains_edge_node_conn:
assert "edge_node_connectivity" not in grid_subset._ds


def test_grid_node_isel():
Expand Down
22 changes: 15 additions & 7 deletions uxarray/grid/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,23 @@ def _slice_face_indices(
node_indices = np.unique(grid.face_node_connectivity.values[face_indices].ravel())
node_indices = node_indices[node_indices != INT_FILL_VALUE]

# edges of each face (inclusive)
edge_indices = np.unique(grid.face_edge_connectivity.values[face_indices].ravel())
edge_indices = edge_indices[edge_indices != INT_FILL_VALUE]

# index original dataset to obtain a 'subgrid'
ds = ds.isel(n_node=node_indices)
ds = ds.isel(n_face=face_indices)
ds = ds.isel(n_edge=edge_indices)

# Only slice edge dimension if we already have the connectivity
if "face_edge_connectivity" in grid._ds:
edge_indices = np.unique(
grid.face_edge_connectivity.values[face_indices].ravel()
)
edge_indices = edge_indices[edge_indices != INT_FILL_VALUE]
ds = ds.isel(n_edge=edge_indices)
ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"])
else:
edge_indices = None

ds["subgrid_node_indices"] = xr.DataArray(node_indices, dims=["n_node"])
ds["subgrid_face_indices"] = xr.DataArray(face_indices, dims=["n_face"])
ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"])

# mapping to update existing connectivity
node_indices_dict = {
Expand Down Expand Up @@ -152,9 +157,12 @@ def _slice_face_indices(

index_types = {
"face": face_indices,
"edge": edge_indices,
"node": node_indices,
}

if edge_indices is not None:
index_types["edge"] = edge_indices

if isinstance(inverse_indices, bool):
inverse_indices_ds["face"] = face_indices
else:
Expand Down
4 changes: 2 additions & 2 deletions uxarray/subset/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def bounding_circle(
self,
center_coord: Union[Tuple, List, np.ndarray],
r: Union[float, int],
element: Optional[str] = "nodes",
element: Optional[str] = "face centers",
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
Expand Down Expand Up @@ -97,7 +97,7 @@ def nearest_neighbor(
self,
center_coord: Union[Tuple, List, np.ndarray],
k: int,
element: Optional[str] = "nodes",
element: Optional[str] = "face centers",
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions uxarray/subset/grid_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def bounding_circle(
self,
center_coord: Union[Tuple, List, np.ndarray],
r: Union[float, int],
element: Optional[str] = "nodes",
element: Optional[str] = "face centers",
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
Expand Down Expand Up @@ -108,7 +108,7 @@ def nearest_neighbor(
self,
center_coord: Union[Tuple, List, np.ndarray],
k: int,
element: Optional[str] = "nodes",
element: Optional[str] = "face centers",
inverse_indices: Union[List[str], Set[str], bool] = False,
**kwargs,
):
Expand Down

0 comments on commit fe4cae1

Please sign in to comment.