Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions rex/resource_extraction/resource_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def _save_tree(tree, tree_path):
.format(tree_path, e))

@staticmethod
def _get_ds_slice(dset, gids):
def _get_ds_slice(dset, gids, ds_ndim):
"""
Get dataset region slice

Expand All @@ -437,18 +437,25 @@ def _get_ds_slice(dset, gids):
Dataset to extract region from
gids : ndarray | list
Gids associated with region
ds_ndim : int
Number of dimensions for dset. 1D is assumed to be spatial, 2D is
assumed to be (temporal, spatial).

Returns
-------
ds_slice : tuple
ds slice tuple to properly extract region from given dataset
"""
if dset == 'time_index':
if dset.startswith('time_index'):
ds_slice = (slice(None), )
elif dset in ['coordinates', 'meta']:
ds_slice = (gids, )
else:
elif ds_ndim == 1:
ds_slice = (gids,)
elif ds_ndim == 2:
ds_slice = (slice(None), gids)
else:
ds_slice = (slice(None), gids) + (slice(None),) * (ds_ndim - 2)

return ds_slice

Expand Down Expand Up @@ -1449,7 +1456,7 @@ def save_subset(self, out_fpath, gids, datasets=None):
for dset in datasets:
if dset in self:
ds = self.h5[dset]
ds_slice = self._get_ds_slice(dset, gids)
ds_slice = self._get_ds_slice(dset, gids, ds.ndim)
data = ResourceDataset.extract(ds, ds_slice,
scale_attr=scale_attr,
add_attr=add_attr,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_resource_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,23 @@ def test_resourcex_iterable(NSRDBX_cls):
assert len(dsets_permutation) == num_dsets ** 2


def test_save_subset(NSRDBX_cls):
"""
Run test saving a subset of a file to a new file with
``ResourceX.save_subset()``
"""
gids = [10, 12, 13, 14, 20]
with tempfile.TemporaryDirectory() as td:
fp_out = os.path.join(td, 'test.h5')
NSRDBX_cls.save_subset(fp_out, gids, datasets=NSRDBX_cls.dsets)

with NSRDBX(fp_out) as out:
assert_frame_equal(NSRDBX_cls['meta', gids].reset_index(drop=True),
out['meta'].reset_index(drop=True))
assert (NSRDBX_cls.time_index == out.time_index).all()
assert np.allclose(NSRDBX_cls['ghi', :, gids], out['ghi'])


def execute_pytest(capture='all', flags='-rapP'):
"""Execute module as pytest with detailed summary report.

Expand Down