Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 6, 2024
1 parent a731eb1 commit 9a54d4f
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 76 deletions.
85 changes: 27 additions & 58 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import os
import re
import warnings
from functools import cached_property
from functools import cached_property, wraps
from pathlib import PurePath

import numpy as np
Expand All @@ -22,6 +22,7 @@

from .indexing import (
apply_index_to_slices_changes,
expand_list_indexing,
index_to_slices,
length_to_slices,
update_tuple,
Expand All @@ -37,6 +38,7 @@


def _debug_indexing(method):
@wraps(method)
def wrapper(self, index):
global DEPTH
if isinstance(index, tuple):
Expand Down Expand Up @@ -224,6 +226,7 @@ def __repr__(self):
return self.__class__.__name__ + "()"

@debug_indexing
@expand_list_indexing
def _get_tuple(self, n):
raise NotImplementedError(
f"Tuple not supported: {n} (class {self.__class__.__name__})"
Expand Down Expand Up @@ -381,30 +384,10 @@ def __len__(self):
return self.data.shape[0]

@debug_indexing
@expand_list_indexing
def __getitem__(self, n):
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
return self._getitem_extended(n)

return self.data[n]

def _getitem_extended(self, index):
"""
Allows to use slices, lists, and tuples to select data from the dataset.
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
"""

assert False, index

shape = self.data.shape

axes = []
data = []
for n in self._unwind(index[0], index[1:], shape, 0, axes):
data.append(self.data[n])

assert len(axes) == 1, axes # Not implemented for more than one axis
return np.concatenate(data, axis=axes[0])

def _unwind(self, index, rest, shape, axis, axes):
if not isinstance(index, (int, slice, list, tuple)):
try:
Expand Down Expand Up @@ -676,28 +659,20 @@ def __len__(self):
return sum(len(i) for i in self.datasets)

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
index, changes = index_to_slices(index, self.shape)
result = []

first, rest = index[0], index[1:]
start, stop, step = first.start, first.stop, first.step

for d in self.datasets:
length = d._len

result.append(d[(slice(start, stop, step),) + rest])

start -= length
while start < 0:
start += step

stop -= length

if start > stop:
break

return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)
print(index, changes)
lengths = [d.shape[0] for d in self.datasets]
slices = length_to_slices(index[0], lengths)
print("slies", slices)
result = [
d[update_tuple(index, 0, i)[0]]
for (d, i) in zip(self.datasets, slices)
if i is not None
]
result = np.concatenate(result, axis=0)
return apply_index_to_slices_changes(result, changes)

@debug_indexing
def __getitem__(self, n):
Expand All @@ -718,21 +693,10 @@ def __getitem__(self, n):
def _get_slice(self, s):
result = []

start, stop, step = s.indices(self._len)

for d in self.datasets:
length = d._len

result.append(d[start:stop:step])

start -= length
while start < 0:
start += step

stop -= length
lengths = [d.shape[0] for d in self.datasets]
slices = length_to_slices(s, lengths)

if start > stop:
break
result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]

return np.concatenate(result)

Expand Down Expand Up @@ -783,13 +747,15 @@ def shape(self):
return result

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
index, changes = index_to_slices(index, self.shape)
lengths = [d.shape[self.axis] for d in self.datasets]
slices = length_to_slices(index[self.axis], lengths)
before = index[: self.axis]
result = [
d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
d[update_tuple(index, self.axis, i)[0]]
for (d, i) in zip(self.datasets, slices)
if i is not None
]
result = np.concatenate(result, axis=self.axis)
return apply_index_to_slices_changes(result, changes)
Expand Down Expand Up @@ -850,6 +816,7 @@ def __len__(self):
return len(self.datasets[0])

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
index, changes = index_to_slices(index, self.shape)
index, previous = update_tuple(index, 1, slice(None))
Expand Down Expand Up @@ -977,6 +944,7 @@ def _get_slice(self, s):
return np.stack([self.dataset[i] for i in indices])

@debug_indexing
@expand_list_indexing
def _get_tuple(self, n):
index, changes = index_to_slices(n, self.shape)
index, previous = update_tuple(index, 0, self.slice)
Expand Down Expand Up @@ -1024,6 +992,7 @@ def __init__(self, dataset, indices):
super().__init__(dataset)

@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
index, changes = index_to_slices(index, self.shape)
index, previous = update_tuple(index, 1, slice(None))
Expand Down
62 changes: 48 additions & 14 deletions ecml_tools/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# nor does it submit to any jurisdiction.


from functools import wraps

import numpy as np


Expand Down Expand Up @@ -109,21 +111,53 @@ def length_to_slices(index, lengths):
return result


class IndexTester:
def __init__(self, shape):
self.shape = shape
def _as_tuples(index):
def _(i):
if hasattr(i, "tolist"):
# NumPy arrays, TensorFlow tensors, etc.
i = i.tolist()
assert not isinstance(i[0], bool), "Mask not supported"
return tuple(i)

if isinstance(i, list):
return tuple(i)

return i

return tuple(_(i) for i in index)


def expand_list_indexing(method):
"""
Allows to use slices, lists, and tuples to select data from the dataset.
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
"""

@wraps(method)
def wrapper(self, index):
if not isinstance(index, tuple):
return method(self, index)

if not any(isinstance(i, (list, tuple)) for i in index):
return method(self, index)

which = []
for i, idx in enumerate(index):
if isinstance(idx, (list, tuple)):
which.append(i)

assert which, "No list index found"

def __getitem__(self, index):
return index_to_slices(index, self.shape)
if len(which) > 1:
raise IndexError("Only one list index is allowed")

which = which[0]
index = _as_tuples(index)
result = []
for i in index[which]:
index, _ = update_tuple(index, which, slice(i, i + 1))
result.append(method(self, index))

if __name__ == "__main__":
t = IndexTester((1000, 8, 10, 20000))
i = t[0, 1, 2, 3]
print(i)
return np.concatenate(result, axis=which)

# print(t[0])
# print(t[0, 1, 2, 3])
# print(t[0:10])
# print(t[...])
# print(t[:-1])
return wrapper
10 changes: 6 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,16 @@ def indexing(ds):
t[0:10, 0:3, 0]
t[:, :, :]

# t[:, (1, 3), :]
# t[:, (1, 3)]
t[:, (1, 3), :]
t[:, (1, 3)]

t[0]
t[0, :]
t[0, 0, :]
t[0, 0, 0, :]

# if ds.shape[2] > 1: # Ensemble dimension
# t[0:10, :, (0, 1)]
if ds.shape[2] > 1: # Ensemble dimension
t[0:10, :, (0, 1)]


def slices(ds, start=None, end=None, step=None):
Expand Down Expand Up @@ -1134,6 +1134,8 @@ def test_ensemble_1():
)
ds = test.ds

ds[0:10,:,(1,2)]

assert isinstance(ds, Ensemble)
assert len(ds) == 365 * 1 * 4
assert len([row for row in ds]) == len(ds)
Expand Down

0 comments on commit 9a54d4f

Please sign in to comment.