Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
better indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 6, 2024
1 parent 453766f commit 0fd9c4f
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 67 deletions.
115 changes: 70 additions & 45 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,36 @@

import ecml_tools

from .indexing import apply_index_to_slices_changes, index_to_slices, length_to_slices
from .indexing import (
apply_index_to_slices_changes,
index_to_slices,
length_to_slices,
update_tuple,
)

LOG = logging.getLogger(__name__)

__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]

DEBUG_ZARR_LOADING = int(os.environ.get("DEBUG_ZARR_LOADING", "0"))

DEPTH = 0  # Current nesting level of traced indexing calls; drives print indentation.


def debug_indexing(method):
    """Decorator that traces tuple-index calls to *method*.

    When the index is a tuple, the call and the resulting array's
    ``.shape`` are printed, indented by the current nesting depth so
    chained dataset wrappers produce a readable call tree.  Non-tuple
    indexes pass through silently (depth is still tracked).
    """
    import functools

    def wrapper(self, index):
        global DEPTH
        traced = isinstance(index, tuple)
        if traced:
            print(" " * DEPTH, "->", self, method.__name__, index)
        DEPTH += 1
        try:
            # Restore DEPTH even if the wrapped method raises, so one
            # failure does not skew the indentation of later traces.
            result = method(self, index)
        finally:
            DEPTH -= 1
        if traced:
            print(" " * DEPTH, "<-", self, method.__name__, result.shape)
        return result

    # Preserve the wrapped method's __name__/__doc__ for trace output
    # and introspection.
    return functools.wraps(method)(wrapper)


def debug_zarr_loading(on_off):
global DEBUG_ZARR_LOADING
Expand Down Expand Up @@ -192,6 +214,7 @@ def metadata_specific(self, **kwargs):
def __repr__(self):
return self.__class__.__name__ + "()"

@debug_indexing
def _get_tuple(self, n):
raise NotImplementedError(
f"Tuple not supported: {n} (class {self.__class__.__name__})"
Expand Down Expand Up @@ -344,6 +367,7 @@ def __init__(self, path):
def __len__(self):
return self.data.shape[0]

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
return self._getitem_extended(n)
Expand Down Expand Up @@ -638,6 +662,7 @@ class Concat(Combined):
def __len__(self):
return sum(len(i) for i in self.datasets)

@debug_indexing
def _get_tuple(self, index):
index, changes = index_to_slices(index, self.shape)
result = []
Expand All @@ -661,6 +686,7 @@ def _get_tuple(self, index):

return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple):
return self._get_tuple(n)
Expand All @@ -675,6 +701,7 @@ def __getitem__(self, n):
k += 1
return self.datasets[k][n]

@debug_indexing
def _get_slice(self, s):
result = []

Expand Down Expand Up @@ -742,24 +769,30 @@ def shape(self):
assert False not in result, result
return result

@debug_indexing
def _get_tuple(self, index):
print(index, self.shape)
index, changes = index_to_slices(index, self.shape)
selected = index[self.axis]
lengths = [d.shape[self.axis] for d in self.datasets]
slices = length_to_slices(selected, lengths)
print("per_dataset_index", slices)
slices = length_to_slices(index[self.axis], lengths)

result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
print("SLICES", slices, self.axis, index, lengths)
before = index[: self.axis]

x = tuple([slice(None)] * self.axis + [selected])
result = [
d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
]
print([d.shape for d in result])
result = np.concatenate(result, axis=self.axis)
print(result.shape)

return apply_index_to_slices_changes(
np.concatenate(result, axis=self.axis)[x], changes
)
return apply_index_to_slices_changes(result, changes)

@debug_indexing
def _get_slice(self, s):
    # Resolve the slice against this dataset's length, fetch each row
    # through __getitem__, and stack the rows into a single array.
    start, stop, step = s.indices(self._len)
    rows = [self[pos] for pos in range(start, stop, step)]
    return np.stack(rows)

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple):
return self._get_tuple(n)
Expand Down Expand Up @@ -810,42 +843,22 @@ def check_same_variables(self, d1, d2):
def __len__(self):
return len(self.datasets[0])

@debug_indexing
def _get_tuple(self, index):
print("Join._get_tuple", index)
assert len(index) > 1, index

index, changes = index_to_slices(index, self.shape)

selected_variables = index[1]

index = list(index)
index[1] = slice(None)
index = tuple(index)
print("Join._get_tuple", index)
index, previous = update_tuple(index, 1, slice(None))

# TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
result = [d[index] for d in self.datasets]

print(
"Join._get_tuple",
self.shape,
[r.shape for r in result],
selected_variables,
changes,
)
result = np.concatenate(result, axis=1)
print("Join._get_tuple", result.shape)

# raise NotImplementedError()

# result = np.concatenate(result)
# result = np.stack(result)

return apply_index_to_slices_changes(result[:, selected_variables], changes)
return apply_index_to_slices_changes(result[:, previous], changes)

@debug_indexing
def _get_slice(self, s):
    # Materialise the slice row by row via __getitem__ and return the
    # rows stacked along a new leading axis.
    selected = range(*s.indices(self._len))
    return np.stack([self[row] for row in selected])

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple):
return self._get_tuple(n)
Expand Down Expand Up @@ -931,10 +944,14 @@ def __init__(self, dataset, indices):

self.dataset = dataset
self.indices = list(indices)
self.slice = _make_slice_or_index_from_list_or_tuple(self.indices)
assert isinstance(self.slice, slice)
print("SUBSET", self.slice)

# Forward other properties to the super dataset
super().__init__(dataset)

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple):
return self._get_tuple(n)
Expand All @@ -945,25 +962,22 @@ def __getitem__(self, n):
n = self.indices[n]
return self.dataset[n]

@debug_indexing
def _get_slice(self, s):
    # TODO: the subset positions might collapse into a single slice on
    # the underlying dataset, but detecting that could cost more time
    # than it would save.
    rows = []
    for pos in range(*s.indices(self._len)):
        # Map the subset-local position to the underlying dataset index.
        rows.append(self.dataset[self.indices[pos]])
    return np.stack(rows)

@debug_indexing
def _get_tuple(self, n):
first, rest = n[0], n[1:]

if isinstance(first, int):
return self.dataset[(self.indices[first],) + rest]

if isinstance(first, slice):
indices = tuple(self.indices[i] for i in range(*first.indices(self._len)))
indices = _make_slice_or_index_from_list_or_tuple(indices)
return self.dataset[(indices,) + rest]

raise NotImplementedError(f"Only int and slice supported not {type(first)}")
index, changes = index_to_slices(n, self.shape)
index, previous = update_tuple(index, 0, self.slice)
result = self.dataset[index]
result = result[previous]
result = apply_index_to_slices_changes(result, changes)
return result

def __len__(self):
return len(self.indices)
Expand Down Expand Up @@ -1003,6 +1017,17 @@ def __init__(self, dataset, indices):
# Forward other properties to the main dataset
super().__init__(dataset)

@debug_indexing
def _get_tuple(self, index):
    # Normalise the incoming tuple index into slices, keeping the
    # bookkeeping needed to restore the caller's original index shape.
    index, changes = index_to_slices(index, self.shape)
    # Request *all* variables (axis 1) from the wrapped dataset, setting
    # the caller's variable selection aside in `previous` for later.
    index, previous = update_tuple(index, 1, slice(None))
    result = self.dataset[index]
    # First narrow to the variables this view exposes, then apply the
    # caller's own selection on top of that subset.
    result = result[:, self.indices]
    result = result[:, previous]
    result = apply_index_to_slices_changes(result, changes)
    return result

@debug_indexing
def __getitem__(self, n):
if isinstance(n, tuple):
return self._get_tuple(n)
Expand Down
33 changes: 17 additions & 16 deletions ecml_tools/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,29 @@ def length_to_slices(index, lengths):
"""
total = sum(lengths)
start, stop, step = index.indices(total)
print(start, stop, step)

# TODO: combine loops
p = []
result = []

pos = 0
for length in lengths:
end = pos + length
p.append((pos, end))
pos = end

result = []
b = max(pos, start)
e = min(end, stop)

for i, (s, e) in enumerate(p):
pos = s
if s % step:
s = s + step - s % step
assert s % step == 0
assert s >= pos
if max(s, start) <= min(e, stop):
result.append((i, slice(s - pos, e - pos, step)))
else:
result.append(None)
p = None
if b <= e:
if (b - start) % step != 0:
b = b + step - (b - start) % step
b -= pos
e -= pos

if 0 <= b < e:
p = slice(b, e, step)

result.append(p)

pos = end

return result

Expand Down
22 changes: 16 additions & 6 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,16 @@ def __init__(self, ds):
self.np = ds[:] # Numpy array

assert self.ds.shape == self.np.shape
assert (self.ds == self.np).all()

def __getitem__(self, index):
assert (self.ds[index] == self.np[index]).all()
if self.ds[index] is None:
assert False, (self.ds, index)

if not (self.ds[index] == self.np[index]).all():
# print("DS", self.ds[index])
# print("NP", self.np[index])
assert (self.ds[index] == self.np[index]).all()


def indexing(ds):
Expand Down Expand Up @@ -310,7 +317,7 @@ def test_concat():
def test_join_1():
ds = open_dataset(
"test-2021-2021-6h-o96-abcd",
"test-2021-2021-6h-o96-efg",
"test-2021-2021-6h-o96-efgh",
)

assert isinstance(ds, Join)
Expand All @@ -330,6 +337,7 @@ def test_join_1():
_(date, "e"),
_(date, "f"),
_(date, "g"),
_(date, "h"),
]
)
assert (row == expect).all()
Expand All @@ -338,7 +346,7 @@ def test_join_1():

assert (ds.dates == np.array(dates, dtype="datetime64")).all()

assert ds.variables == ["a", "b", "c", "d", "e", "f", "g"]
assert ds.variables == ["a", "b", "c", "d", "e", "f", "g", "h"]
assert ds.name_to_index == {
"a": 0,
"b": 1,
Expand All @@ -347,9 +355,10 @@ def test_join_1():
"e": 4,
"f": 5,
"g": 6,
"h": 7,
}

assert ds.shape == (365 * 4, 7, 1, VALUES)
assert ds.shape == (365 * 4, 8, 1, VALUES)

same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
slices(ds)
Expand Down Expand Up @@ -1275,6 +1284,7 @@ def test_ensemble_1():

dates = []
date = datetime.datetime(2021, 1, 1)
indexing(ds)

for row in ds:
expect = make_row(
Expand All @@ -1299,7 +1309,7 @@ def test_ensemble_1():
assert ds.shape == (365 * 4, 4, 11, VALUES)
# same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
slices(ds)
indexing(ds)

metadata(ds)


Expand Down Expand Up @@ -1432,4 +1442,4 @@ def test_statistics():


if __name__ == "__main__":
test_constructor_5()
test_ensemble_1()

0 comments on commit 0fd9c4f

Please sign in to comment.