diff --git a/rex/resource_extraction/resource_extraction.py b/rex/resource_extraction/resource_extraction.py index d28939e9a..62d4c46dd 100644 --- a/rex/resource_extraction/resource_extraction.py +++ b/rex/resource_extraction/resource_extraction.py @@ -427,7 +427,7 @@ def _save_tree(tree, tree_path): .format(tree_path, e)) @staticmethod - def _get_ds_slice(dset, gids): + def _get_ds_slice(dset, gids, ds_ndim): """ Get dataset region slice @@ -437,18 +437,25 @@ def _get_ds_slice(dset, gids): Dataset to extract region from gids : ndarray | list Gids associated with region + ds_ndim : int + Number of dimensions for dset. 1D is assumed to be spatial, 2D is + assumed to be (temporal, spatial). Returns ------- ds_slice : tuple ds slice tuple to properly extract region from given dataset """ - if dset == 'time_index': + if dset.startswith('time_index'): ds_slice = (slice(None), ) elif dset in ['coordinates', 'meta']: ds_slice = (gids, ) - else: + elif ds_ndim == 1: + ds_slice = (gids,) + elif ds_ndim == 2: ds_slice = (slice(None), gids) + else: + ds_slice = (slice(None), gids) + (slice(None),) * (ds_ndim - 2) return ds_slice @@ -1449,7 +1456,7 @@ def save_subset(self, out_fpath, gids, datasets=None): for dset in datasets: if dset in self: ds = self.h5[dset] - ds_slice = self._get_ds_slice(dset, gids) + ds_slice = self._get_ds_slice(dset, gids, ds.ndim) data = ResourceDataset.extract(ds, ds_slice, scale_attr=scale_attr, add_attr=add_attr, diff --git a/tests/test_resource_extraction.py b/tests/test_resource_extraction.py index 94a171119..b1f8c6ae3 100644 --- a/tests/test_resource_extraction.py +++ b/tests/test_resource_extraction.py @@ -1143,6 +1143,23 @@ def test_resourcex_iterable(NSRDBX_cls): assert len(dsets_permutation) == num_dsets ** 2 +def test_save_subset(NSRDBX_cls): + """ + Run test saving a subset of a file to a new file with + ``ResourceX.save_subset()`` + """ + gids = [10, 12, 13, 14, 20] + with tempfile.TemporaryDirectory() as td: + fp_out = os.path.join(td, 'test.h5') + NSRDBX_cls.save_subset(fp_out, gids, datasets=NSRDBX_cls.dsets) + + with NSRDBX(fp_out) as out: + assert_frame_equal(NSRDBX_cls['meta', gids].reset_index(drop=True), + out['meta'].reset_index(drop=True)) + assert (NSRDBX_cls.time_index == out.time_index).all() + assert np.allclose(NSRDBX_cls['ghi', :, gids], out['ghi']) + + def execute_pytest(capture='all', flags='-rapP'): """Execute module as pytest with detailed summary report.