From 9a54d4ff727e6ba84f6354a734c00371a7dcdd17 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 6 Feb 2024 13:29:54 +0000 Subject: [PATCH] update --- ecml_tools/data.py | 85 ++++++++++++++---------------------------- ecml_tools/indexing.py | 62 +++++++++++++++++++++++------- tests/test_data.py | 10 +++-- 3 files changed, 81 insertions(+), 76 deletions(-) diff --git a/ecml_tools/data.py b/ecml_tools/data.py index ee83922..0dc8f03 100644 --- a/ecml_tools/data.py +++ b/ecml_tools/data.py @@ -11,7 +11,7 @@ import os import re import warnings -from functools import cached_property +from functools import cached_property, wraps from pathlib import PurePath import numpy as np @@ -22,6 +22,7 @@ from .indexing import ( apply_index_to_slices_changes, + expand_list_indexing, index_to_slices, length_to_slices, update_tuple, @@ -37,6 +38,7 @@ def _debug_indexing(method): + @wraps(method) def wrapper(self, index): global DEPTH if isinstance(index, tuple): @@ -224,6 +226,7 @@ def __repr__(self): return self.__class__.__name__ + "()" @debug_indexing + @expand_list_indexing def _get_tuple(self, n): raise NotImplementedError( f"Tuple not supported: {n} (class {self.__class__.__name__})" @@ -381,30 +384,10 @@ def __len__(self): return self.data.shape[0] @debug_indexing + @expand_list_indexing def __getitem__(self, n): - if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n): - return self._getitem_extended(n) - return self.data[n] - def _getitem_extended(self, index): - """ - Allows to use slices, lists, and tuples to select data from the dataset. - Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves. - """ - - assert False, index - - shape = self.data.shape - - axes = [] - data = [] - for n in self._unwind(index[0], index[1:], shape, 0, axes): - data.append(self.data[n]) - - assert len(axes) == 1, axes # Not implemented for more than one axis - return np.concatenate(data, axis=axes[0]) - def _unwind(self, index, rest, shape, axis, axes): if not isinstance(index, (int, slice, list, tuple)): try: @@ -676,28 +659,20 @@ def __len__(self): return sum(len(i) for i in self.datasets) @debug_indexing + @expand_list_indexing def _get_tuple(self, index): index, changes = index_to_slices(index, self.shape) - result = [] - - first, rest = index[0], index[1:] - start, stop, step = first.start, first.stop, first.step - - for d in self.datasets: - length = d._len - - result.append(d[(slice(start, stop, step),) + rest]) - - start -= length - while start < 0: - start += step - - stop -= length - - if start > stop: - break - - return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes) + print(index, changes) + lengths = [d.shape[0] for d in self.datasets] + slices = length_to_slices(index[0], lengths) + print("slies", slices) + result = [ + d[update_tuple(index, 0, i)[0]] + for (d, i) in zip(self.datasets, slices) + if i is not None + ] + result = np.concatenate(result, axis=0) + return apply_index_to_slices_changes(result, changes) @debug_indexing def __getitem__(self, n): @@ -718,21 +693,10 @@ def __getitem__(self, n): def _get_slice(self, s): result = [] - start, stop, step = s.indices(self._len) - - for d in self.datasets: - length = d._len - - result.append(d[start:stop:step]) - - start -= length - while start < 0: - start += step - - stop -= length + lengths = [d.shape[0] for d in self.datasets] + slices = length_to_slices(s, lengths) - if start > stop: - break + result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None] return np.concatenate(result) @@ -783,13 +747,15 @@ def shape(self): return result @debug_indexing + @expand_list_indexing def _get_tuple(self, index): index, changes = index_to_slices(index, self.shape) lengths = [d.shape[self.axis] for d in self.datasets] slices = length_to_slices(index[self.axis], lengths) - before = index[: self.axis] result = [ - d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None + d[update_tuple(index, self.axis, i)[0]] + for (d, i) in zip(self.datasets, slices) + if i is not None ] result = np.concatenate(result, axis=self.axis) return apply_index_to_slices_changes(result, changes) @@ -850,6 +816,7 @@ def __len__(self): return len(self.datasets[0]) @debug_indexing + @expand_list_indexing def _get_tuple(self, index): index, changes = index_to_slices(index, self.shape) index, previous = update_tuple(index, 1, slice(None)) @@ -977,6 +944,7 @@ def _get_slice(self, s): return np.stack([self.dataset[i] for i in indices]) @debug_indexing + @expand_list_indexing def _get_tuple(self, n): index, changes = index_to_slices(n, self.shape) index, previous = update_tuple(index, 0, self.slice) @@ -1024,6 +992,7 @@ def __init__(self, dataset, indices): super().__init__(dataset) @debug_indexing + @expand_list_indexing def _get_tuple(self, index): index, changes = index_to_slices(index, self.shape) index, previous = update_tuple(index, 1, slice(None)) diff --git a/ecml_tools/indexing.py b/ecml_tools/indexing.py index 7f95a2b..e3557a6 100644 --- a/ecml_tools/indexing.py +++ b/ecml_tools/indexing.py @@ -6,6 +6,8 @@ # nor does it submit to any jurisdiction. +from functools import wraps + import numpy as np @@ -109,21 +111,53 @@ def length_to_slices(index, lengths): return result -class IndexTester: - def __init__(self, shape): - self.shape = shape +def _as_tuples(index): + def _(i): + if hasattr(i, "tolist"): + # NumPy arrays, TensorFlow tensors, etc. + i = i.tolist() + assert not isinstance(i[0], bool), "Mask not supported" + return tuple(i) + + if isinstance(i, list): + return tuple(i) + + return i + + return tuple(_(i) for i in index) + + +def expand_list_indexing(method): + """ + Allows to use slices, lists, and tuples to select data from the dataset. + Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves. + """ + + @wraps(method) + def wrapper(self, index): + if not isinstance(index, tuple): + return method(self, index) + + if not any(isinstance(i, (list, tuple)) for i in index): + return method(self, index) + + which = [] + for i, idx in enumerate(index): + if isinstance(idx, (list, tuple)): + which.append(i) + + assert which, "No list index found" - def __getitem__(self, index): - return index_to_slices(index, self.shape) + if len(which) > 1: + raise IndexError("Only one list index is allowed") + which = which[0] + index = _as_tuples(index) + result = [] + for i in index[which]: + index, _ = update_tuple(index, which, slice(i, i + 1)) + result.append(method(self, index)) -if __name__ == "__main__": - t = IndexTester((1000, 8, 10, 20000)) - i = t[0, 1, 2, 3] - print(i) + return np.concatenate(result, axis=which) - # print(t[0]) - # print(t[0, 1, 2, 3]) - # print(t[0:10]) - # print(t[...]) - # print(t[:-1]) + return wrapper diff --git a/tests/test_data.py b/tests/test_data.py index cd27228..0e6e1f0 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -201,16 +201,16 @@ def indexing(ds): t[0:10, 0:3, 0] t[:, :, :] - # t[:, (1, 3), :] - # t[:, (1, 3)] + t[:, (1, 3), :] + t[:, (1, 3)] t[0] t[0, :] t[0, 0, :] t[0, 0, 0, :] - # if ds.shape[2] > 1: # Ensemble dimension - # t[0:10, :, (0, 1)] + if ds.shape[2] > 1: # Ensemble dimension + t[0:10, :, (0, 1)] def slices(ds, start=None, end=None, step=None): @@ -1134,6 +1134,8 @@ def test_ensemble_1(): ) ds = test.ds + ds[0:10,:,(1,2)] + assert isinstance(ds, Ensemble) assert len(ds) == 365 * 1 * 4 assert len([row for row in ds]) == len(ds)