From cad646a7600b48f32f343b006edf3b18224cc46e Mon Sep 17 00:00:00 2001
From: Baudouin Raoult
Date: Mon, 5 Feb 2024 18:56:44 +0000
Subject: [PATCH] Work on indexing

---
 ecml_tools/data.py     | 68 +++++++++++++++++++++++++++-------
 ecml_tools/indexing.py | 84 ++++++++++++++++++++++++++++++++++++++++++
 tests/test_data.py     | 72 ++++++++++++++++++++++++++----------
 3 files changed, 192 insertions(+), 32 deletions(-)
 create mode 100644 ecml_tools/indexing.py

diff --git a/ecml_tools/data.py b/ecml_tools/data.py
index dafa4a7..1fcaa61 100644
--- a/ecml_tools/data.py
+++ b/ecml_tools/data.py
@@ -20,6 +20,8 @@

 import ecml_tools

+from .indexing import apply_index_to_slices_changes, index_to_slices
+
 LOG = logging.getLogger(__name__)

 __all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]
@@ -55,11 +57,20 @@ def _tuple_with_slices(t):
     """

     result = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in t)
-    changes = [j for (j, i) in enumerate(t) if isinstance(i, int)]
+    changes = tuple(j for (j, i) in enumerate(t) if isinstance(i, int))

     return result, changes


+def _apply_tuple_changes(result, changes):
+    if changes:
+        shape = result.shape
+        for i in changes:
+            assert shape[i] == 1, shape
+        result = np.squeeze(result, axis=changes)
+    return result
+
+
 class Dataset:
     arguments = {}

@@ -367,7 +378,6 @@ def _getitem_extended(self, index):

         assert False, index

-
         shape = self.data.shape

         axes = []
@@ -648,6 +658,29 @@ class Concat(Combined):
     def __len__(self):
         return sum(len(i) for i in self.datasets)

+    def _get_tuple(self, index):
+        index, changes = index_to_slices(index, self.shape)
+        result = []
+
+        first, rest = index[0], index[1:]
+        start, stop, step = first.start, first.stop, first.step
+
+        for d in self.datasets:
+            length = d._len
+
+            result.append(d[(slice(start, stop, step),) + rest])
+
+            start -= length
+            while start < 0:
+                start += step
+
+            stop -= length
+
+            if start > stop:
+                break
+
+        return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)
+
     def __getitem__(self, n):
         if isinstance(n, tuple):
             return self._get_tuple(n)
@@ -783,23 +816,19 @@ def __len__(self):
         return len(self.datasets[0])

     def _get_tuple(self, index):
         assert len(index) > 1, index
-        selected_variables = index[1]

-        index, changed = _tuple_with_slices(index)
+        index, changes = index_to_slices(index, self.shape)
+        selected_variables = index[1]
+        index = list(index)
         index[1] = slice(None)
         index = tuple(index)

+        # TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
         result = [d[index] for d in self.datasets]
-        print(self.shape, [r.shape for r in result], selected_variables, changed)
-        result = np.stack(result)
-        print(result.shape)
+        result = np.concatenate(result, axis=1)

-        raise NotImplementedError()
-
-        result = np.concatenate(result)
-        # result = np.stack(result)
-        return result[index[1]]
+        return apply_index_to_slices_changes(result[:, selected_variables], changes)

     def _get_slice(self, s):
         return np.stack([self[i] for i in range(*s.indices(self._len))])
@@ -967,11 +1009,11 @@ def __init__(self, dataset, indices):
         super().__init__(dataset)

     def __getitem__(self, n):
-        # if isinstance(n, tuple):
-        #     return self._get_tuple(n)
+        if isinstance(n, tuple):
+            return self._get_tuple(n)

         row = self.dataset[n]
-        if isinstance(n, (slice, tuple)):
+        if isinstance(n, slice):
             return row[:, self.indices]
         return row[self.indices]

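Note on the Concat change: _get_tuple re-bases the leading date slice into each member dataset's local index space, realigning the start with the stride before moving to the next dataset, and stops once the window has been exhausted. The arithmetic can be checked in isolation with plain NumPy arrays; concat_slice is an ad-hoc helper written only for this note, and the sizes below are illustrative rather than taken from the patch:

    import numpy as np

    def concat_slice(datasets, start, stop, step):
        # Mirrors the start/stop re-basing done by Concat._get_tuple.
        result = []
        for d in datasets:
            length = len(d)
            result.append(d[start:stop:step])
            start -= length
            while start < 0:
                start += step  # realign the start with the stride
            stop -= length
            if start > stop:
                break
        return np.concatenate(result)

    full = np.arange(30)
    parts = [full[:10], full[10:25], full[25:]]
    assert (concat_slice(parts, 2, 29, 3) == full[2:29:3]).all()

The variable, ensemble and grid slices are passed through unchanged to every member dataset, so only the date axis needs this treatment.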
diff --git a/ecml_tools/indexing.py b/ecml_tools/indexing.py
new file mode 100644
index 0000000..8f61abd
--- /dev/null
+++ b/ecml_tools/indexing.py
@@ -0,0 +1,84 @@
+# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import numpy as np
+
+
+def _tuple_with_slices(t, shape):
+    """
+    Replace all integers in a tuple with slices, so we preserve the dimensionality.
+    """
+
+    result = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in t)
+    changes = tuple(j for (j, i) in enumerate(t) if isinstance(i, int))
+    result = tuple(slice(*s.indices(shape[i])) for (i, s) in enumerate(result))
+
+    return result, changes
+
+
+def _extend_shape(index, shape):
+    if Ellipsis in index:
+        if index.count(Ellipsis) > 1:
+            raise IndexError("Only one Ellipsis is allowed")
+        ellipsis_index = index.index(Ellipsis)
+        index = list(index)
+        index[ellipsis_index] = slice(None)
+        while len(index) < len(shape):
+            index.insert(ellipsis_index, slice(None))
+        index = tuple(index)
+
+    while len(index) < len(shape):
+        index = index + (slice(None),)
+
+    return index
+
+
+def _index_to_tuple(index, shape):
+    if isinstance(index, int):
+        return _extend_shape((index,), shape)
+    if isinstance(index, slice):
+        return _extend_shape((index,), shape)
+    if isinstance(index, tuple):
+        return _extend_shape(index, shape)
+    if index is Ellipsis:
+        return _extend_shape((Ellipsis,), shape)
+    raise ValueError(f"Invalid index: {index}")
+
+
+def index_to_slices(index, shape):
+    """
+    Convert an index to a tuple of slices, with the same dimensionality as the shape.
+    """
+    return _tuple_with_slices(_index_to_tuple(index, shape), shape)
+
+
+def apply_index_to_slices_changes(result, changes):
+    if changes:
+        shape = result.shape
+        for i in changes:
+            assert shape[i] == 1, shape
+        result = np.squeeze(result, axis=changes)
+    return result
+
+
+class IndexTester:
+    def __init__(self, shape):
+        self.shape = shape
+
+    def __getitem__(self, index):
+        return index_to_slices(index, self.shape)
+
+
+if __name__ == "__main__":
+    t = IndexTester((1000, 8, 10, 20000))
+
+    print(t[0])
+    print(t[0, 1, 2, 3])
+    print(t[0:10])
+    print(t[...])
+    print(t[:-1])
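The helpers above normalise any supported index (an int, a slice, Ellipsis, or a tuple of these) into a full tuple of concrete slices, and record which axes were integer-indexed so they can be squeezed out again after the data has been read. A small usage sketch, where the 3-D array is just a stand-in for a lazily indexed dataset:

    import numpy as np
    from ecml_tools.indexing import apply_index_to_slices_changes, index_to_slices

    data = np.arange(2 * 3 * 4).reshape(2, 3, 4)

    slices, changes = index_to_slices((1, slice(0, 2)), data.shape)
    # slices == (slice(1, 2, 1), slice(0, 2, 1), slice(0, 4, 1)), changes == (0,)

    subset = apply_index_to_slices_changes(data[slices], changes)
    assert (subset == data[1, 0:2]).all()

Keeping every entry a slice preserves dimensionality while the lazy datasets combine their pieces; the squeeze happens only once, on the final result.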
diff --git a/tests/test_data.py b/tests/test_data.py
index 692b8a5..61e12c7 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -184,22 +184,9 @@ def __getitem__(self, index):
         assert (self.ds[index] == self.np[index]).all()


-def slices(ds, start=None, end=None, step=None):
-    if start is None:
-        start = 5
-    if end is None:
-        end = len(ds) - 5
-    if step is None:
-        step = len(ds) // 10
-
+def indexing(ds):
     t = IndexTester(ds)

-    t[start:end:step]
-    t[start:end]
-    t[start:]
-    t[:end]
-    t[::step]
-
     t[0:10, :, 0]
     t[:, 0:3, 0]
     t[:, :, 0]
@@ -218,6 +205,23 @@
     # t[0:10, :, (0, 1)]


+def slices(ds, start=None, end=None, step=None):
+    if start is None:
+        start = 5
+    if end is None:
+        end = len(ds) - 5
+    if step is None:
+        step = len(ds) // 10
+
+    t = IndexTester(ds)
+
+    t[start:end:step]
+    t[start:end]
+    t[start:]
+    t[:end]
+    t[::step]
+
+
 def make_row(args, ensemble=False, grid=False):

     if grid:
@@ -266,6 +270,7 @@ def test_simple():
     same_stats(ds, open_dataset("test-2021-2022-6h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -298,13 +303,14 @@ def test_concat():
     same_stats(ds, open_dataset("test-2021-2022-6h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


 def test_join_1():
     ds = open_dataset(
         "test-2021-2021-6h-o96-abcd",
-        "test-2021-2021-6h-o96-efgh",
+        "test-2021-2021-6h-o96-efg",
     )
     assert isinstance(ds, Join)

@@ -324,7 +330,6 @@
                 _(date, "e"),
                 _(date, "f"),
                 _(date, "g"),
-                _(date, "h"),
             ]
         )
         assert (row == expect).all()
@@ -333,7 +338,7 @@

     assert (ds.dates == np.array(dates, dtype="datetime64")).all()

-    assert ds.variables == ["a", "b", "c", "d", "e", "f", "g", "h"]
+    assert ds.variables == ["a", "b", "c", "d", "e", "f", "g"]
     assert ds.name_to_index == {
         "a": 0,
         "b": 1,
@@ -342,13 +347,13 @@
         "e": 4,
         "f": 5,
         "g": 6,
-        "h": 7,
     }

-    assert ds.shape == (365 * 4, 8, 1, VALUES)
+    assert ds.shape == (365 * 4, 7, 1, VALUES)

     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -396,6 +401,7 @@ def test_join_2():
         "abcdef",
     )
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -435,6 +441,7 @@ def test_join_3():
     assert ds.shape == (365 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd-2"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -469,6 +476,7 @@ def test_subset_1():
     assert ds.shape == (365 * 3 * 2, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -504,6 +512,7 @@ def test_subset_2():
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -544,6 +553,7 @@ def test_subset_3():
     assert ds.shape == (365 * 2, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -579,6 +589,7 @@ def test_subset_4():
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -614,6 +625,7 @@ def test_subset_5():
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -653,6 +665,7 @@ def test_subset_6():
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -688,6 +701,7 @@ def test_subset_7():
     same_stats(ds, open_dataset("test-2021-2023-1h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -730,6 +744,7 @@ def test_subset_8():
     same_stats(ds, open_dataset("test-2021-2021-1h-o96-abcd"), "abcd")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -757,6 +772,7 @@ def test_select_1():
     assert ds.shape == (365 * 4, 2, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "bd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -784,6 +800,7 @@ def test_select_2():
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "ac")

     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -811,6 +828,7 @@ def test_select_3():
     assert ds.shape == (365 * 4, 2, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "ac")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -838,6 +856,7 @@ def test_rename():
     assert ds.shape == (365 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "xbyd", "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -864,6 +883,7 @@ def test_drop():
     assert ds.shape == (365 * 4, 3, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "bcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -891,6 +911,7 @@ def test_reorder_1():
     assert ds.shape == (365 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -918,6 +939,7 @@ def test_reorder_2():
     assert ds.shape == (365 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -949,6 +971,7 @@ def test_constructor_1():
     assert ds.shape == (365 * 2 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -978,6 +1001,7 @@ def test_constructor_2():
     assert ds.shape == (365 * 2 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -1007,6 +1031,7 @@ def test_constructor_3():
     assert ds.shape == (365 * 2 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -1037,6 +1062,7 @@ def test_constructor_4():
     assert ds.shape == (365 * 2 * 4, 4, 1, VALUES)
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -1080,6 +1106,7 @@ def test_constructor_5():
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd-1"), "xyd", "acd")
     same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd-2"), "abzt", "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -1160,6 +1187,7 @@ def test_slice_6():
     ds = open_dataset([f"test-{year}-{year}-1h-o96-abcd" for year in range(1940, 2023)])

     slices(ds)
+    indexing(ds)
     metadata(ds)
     slices(ds, 0, len(ds), 1)
     slices(ds, 0, len(ds), 10)
@@ -1181,6 +1209,7 @@ def test_slice_7():
     )

     slices(ds)
+    indexing(ds)
     metadata(ds)
     slices(ds, 0, len(ds), 1)
     slices(ds, 0, len(ds), 10)
@@ -1199,6 +1228,7 @@ def test_slice_8():
     )

     slices(ds)
+    indexing(ds)
     metadata(ds)
     slices(ds, 0, len(ds), 1)
     slices(ds, 0, len(ds), 10)
@@ -1218,6 +1248,7 @@ def test_slice_9():
     )

     slices(ds)
+    indexing(ds)
     metadata(ds)
     slices(ds, 0, len(ds), 1)
     slices(ds, 0, len(ds), 10)
@@ -1268,6 +1299,7 @@ def test_ensemble_1():
     assert ds.shape == (365 * 4, 4, 11, VALUES)
     # same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -1317,6 +1349,7 @@ def test_ensemble_2():
     assert ds.shape == (365 * 4, 4, 16, VALUES)
     # same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)


@@ -1371,6 +1404,7 @@ def test_grids():
     assert ds.shape == (365 * 4, 4, 1, VALUES + 25)
     # same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
     slices(ds)
+    indexing(ds)
     metadata(ds)

     ds1 = open_dataset("test-2021-2021-6h-o96-abcd-1-1")
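For the joined case, Join._get_tuple now asks every member dataset for all of its variables, concatenates along the variable axis, and only then applies the requested variable selection and squeezes the integer-indexed axes. The same flow can be followed with two in-memory arrays standing in for the member datasets; the shapes below are illustrative, and the module is assumed to be importable as ecml_tools.indexing:

    import numpy as np
    from ecml_tools.indexing import apply_index_to_slices_changes, index_to_slices

    # Two joined "datasets": 4 dates, (3 + 2) variables, 1 ensemble member, 5 grid points.
    a = np.random.default_rng(1).random((4, 3, 1, 5))
    b = np.random.default_rng(2).random((4, 2, 1, 5))
    joined = np.concatenate([a, b], axis=1)

    index = (slice(0, 2), 3, 0)  # dates 0-1, variable 3, first ensemble member

    slices, changes = index_to_slices(index, joined.shape)
    selected_variables = slices[1]
    request = list(slices)
    request[1] = slice(None)  # ask each dataset for all of its variables
    request = tuple(request)

    result = np.concatenate([a[request], b[request]], axis=1)
    result = apply_index_to_slices_changes(result[:, selected_variables], changes)

    assert (result == joined[index]).all()

The new indexing(ds) calls in the tests exercise exactly this kind of tuple index against the zarr-backed test datasets, comparing each result with the equivalent NumPy indexing.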