Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
Work on indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 5, 2024
1 parent 61eb457 commit cad646a
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 32 deletions.
68 changes: 55 additions & 13 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import ecml_tools

from .indexing import apply_index_to_slices_changes, index_to_slices

LOG = logging.getLogger(__name__)

__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]
Expand Down Expand Up @@ -55,11 +57,20 @@ def _tuple_with_slices(t):
"""

result = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in t)
changes = [j for (j, i) in enumerate(t) if isinstance(i, int)]
changes = tuple(j for (j, i) in enumerate(t) if isinstance(i, int))

return result, changes


def _apply_tuple_changes(result, changes):
    """
    Remove from `result` the axes listed in `changes`.

    `changes` holds the positions of the axes that were indexed with an
    integer and temporarily widened to a length-one slice; squeezing them
    restores the dimensionality the caller asked for.

    Parameters
    ----------
    result : numpy.ndarray
        Array obtained with the slice-only version of the index.
    changes : iterable of int or None
        Axes to drop. Any iterable is accepted; an empty one (or None)
        returns `result` unchanged.

    Returns
    -------
    numpy.ndarray
        `result` without the listed axes.
    """
    # np.squeeze only accepts an int or a *tuple* of ints for `axis`;
    # a list would raise TypeError, so normalise first.
    axes = tuple(changes) if changes is not None else ()
    if not axes:
        return result

    shape = result.shape
    for axis in axes:
        # Each of these axes comes from a one-element slice.
        assert shape[axis] == 1, shape

    return np.squeeze(result, axis=axes)


class Dataset:
arguments = {}

Expand Down Expand Up @@ -367,7 +378,6 @@ def _getitem_extended(self, index):

assert False, index


shape = self.data.shape

axes = []
Expand Down Expand Up @@ -648,6 +658,29 @@ class Concat(Combined):
def __len__(self):
return sum(len(i) for i in self.datasets)

def _get_tuple(self, index):
    """
    Return the data selected by a tuple index across the concatenated datasets.

    The index is first normalised with `index_to_slices`, so every entry is
    a slice and `changes` records the axes that must be squeezed at the end.
    The slice on the first axis is then walked through `self.datasets` in
    order: each dataset is asked for the part of the slice expressed in its
    own coordinates, and the pieces are concatenated along axis 0.

    Raises
    ------
    NotImplementedError
        If the first-axis slice has a negative step. The walk below only
        moves forward through the datasets, so a negative step used to
        loop forever or return wrong data.
    """
    index, changes = index_to_slices(index, self.shape)

    first, rest = index[0], index[1:]
    start, stop, step = first.start, first.stop, first.step

    if step < 0:
        raise NotImplementedError(
            f"Negative steps are not supported when indexing a Concat: {first}"
        )

    result = []
    for d in self.datasets:
        length = d._len

        result.append(d[(slice(start, stop, step),) + rest])

        # Re-express the slice relative to the beginning of the next dataset.
        start -= length
        if start < 0:
            # Smallest non-negative position reachable from `start` in
            # increments of `step` (same value the former
            # `while start < 0: start += step` loop produced).
            start %= step

        stop -= length

        # Nothing of the slice is left for the remaining datasets.
        if start > stop:
            break

    return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)

def __getitem__(self, n):
if isinstance(n, tuple):
return self._get_tuple(n)
Expand Down Expand Up @@ -783,28 +816,37 @@ def __len__(self):
return len(self.datasets[0])

def _get_tuple(self, index):
print("Join._get_tuple", index)
assert len(index) > 1, index

selected_variables = index[1]

index, changed = _tuple_with_slices(index)

selected_variables = index[1]

index = list(index)
index[1] = slice(None)
index = tuple(index)
print("Join._get_tuple", index)

# TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
result = [d[index] for d in self.datasets]

print(self.shape, [r.shape for r in result], selected_variables, changed)
result = np.stack(result)
print(result.shape)
print(
"Join._get_tuple",
self.shape,
[r.shape for r in result],
selected_variables,
changed,
)
result = np.concatenate(result, axis=1)
print("Join._get_tuple", result.shape)

raise NotImplementedError()
# raise NotImplementedError()

result = np.concatenate(result)
# result = np.concatenate(result)
# result = np.stack(result)

return result[index[1]]
return _apply_tuple_changes(result[:, selected_variables], changed)

def _get_slice(self, s):
return np.stack([self[i] for i in range(*s.indices(self._len))])
Expand Down Expand Up @@ -967,11 +1009,11 @@ def __init__(self, dataset, indices):
super().__init__(dataset)

def __getitem__(self, n):
# if isinstance(n, tuple):
# return self._get_tuple(n)
if isinstance(n, tuple):
return self._get_tuple(n)

row = self.dataset[n]
if isinstance(n, (slice, tuple)):
if isinstance(n, slice):
return row[:, self.indices]

return row[self.indices]
Expand Down
84 changes: 84 additions & 0 deletions ecml_tools/indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import numpy as np


def _tuple_with_slices(t, shape):
    """
    Replace all integers in a tuple with slices, so we preserve the dimensionality.

    Parameters
    ----------
    t : tuple
        One entry per axis, each an integer (Python or NumPy) or a slice.
    shape : tuple of int
        Shape of the indexed object; used to resolve negative and open-ended
        bounds.

    Returns
    -------
    (tuple of slice, tuple of int)
        The index with every entry turned into a slice with explicit
        start/stop/step, and the positions of the entries that were integers
        (the axes to squeeze afterwards).

    Raises
    ------
    IndexError
        If `t` has more entries than `shape`, or an integer entry is out of
        bounds for its axis.
    """
    if len(t) > len(shape):
        raise IndexError(f"Too many indices ({len(t)}) for shape {shape}")

    result = []
    changes = []

    for axis, item in enumerate(t):
        size = shape[axis]

        if isinstance(item, (int, np.integer)):
            position = int(item)
            # Resolve negative integers here: slice(-1, 0) would otherwise
            # be interpreted by slice.indices() as an empty slice.
            if position < 0:
                position += size
            if not 0 <= position < size:
                raise IndexError(
                    f"Index {item} is out of bounds for axis {axis} with size {size}"
                )
            item = slice(position, position + 1)
            changes.append(axis)

        # Make start/stop/step explicit and clipped to the axis length.
        result.append(slice(*item.indices(size)))

    return tuple(result), tuple(changes)


def _extend_shape(index, shape):
    """
    Expand `index` so that it has one entry per axis of `shape`.

    A single Ellipsis is replaced by as many `slice(None)` as needed to
    reach `len(shape)` entries (possibly none, when the other entries
    already cover every axis). Without an Ellipsis, missing trailing axes
    are filled with `slice(None)`.

    Parameters
    ----------
    index : tuple
        Index entries, at most one of which may be Ellipsis.
    shape : tuple of int
        Shape of the indexed object.

    Returns
    -------
    tuple
        The expanded index, free of Ellipsis.

    Raises
    ------
    IndexError
        If `index` contains more than one Ellipsis.
    """
    index = tuple(index)

    # Identity test: entries are only ever compared with `is`, never `==`.
    ellipsis_positions = [i for i, item in enumerate(index) if item is Ellipsis]

    if len(ellipsis_positions) > 1:
        raise IndexError("Only one Ellipsis is allowed")

    if ellipsis_positions:
        where = ellipsis_positions[0]
        # The Ellipsis itself does not count as an axis, hence the "- 1".
        missing = max(len(shape) - (len(index) - 1), 0)
        return index[:where] + (slice(None),) * missing + index[where + 1 :]

    missing = max(len(shape) - len(index), 0)
    return index + (slice(None),) * missing


def _index_to_tuple(index, shape):
    """
    Turn `index` into a tuple and pass it to `_extend_shape` with `shape`.

    A tuple is used as is; a single int, slice or Ellipsis is wrapped in a
    one-element tuple. Anything else raises ValueError.
    """
    if isinstance(index, tuple):
        as_tuple = index
    elif isinstance(index, (int, slice)) or index is Ellipsis:
        as_tuple = (index,)
    else:
        raise ValueError(f"Invalid index: {index}")

    return _extend_shape(as_tuple, shape)


def index_to_slices(index, shape):
    """
    Convert an index to a tuple of slices, with the same dimensionality as the shape.

    Returns whatever `_tuple_with_slices` returns for the tuple produced by
    `_index_to_tuple`.
    """
    full_index = _index_to_tuple(index, shape)
    return _tuple_with_slices(full_index, shape)


def apply_index_to_slices_changes(result, changes):
    """
    Squeeze out of `result` the axes listed in `changes`.

    `changes` is the second value returned by `index_to_slices`: the
    positions of the index entries that were integers and were replaced by
    one-element slices.

    Parameters
    ----------
    result : numpy.ndarray
        Array obtained by indexing with the tuple of slices.
    changes : iterable of int or None
        Axes to remove. An empty iterable (or None) leaves `result`
        untouched.

    Returns
    -------
    numpy.ndarray
        `result` with the listed axes removed.
    """
    if changes is None:
        return result

    # np.squeeze needs an int or a tuple of ints for `axis`; a list raises
    # TypeError, so accept any iterable and normalise it.
    axes = tuple(changes)
    if not axes:
        return result

    shape = result.shape
    for axis in axes:
        # Axes recorded in `changes` come from one-element slices.
        assert shape[axis] == 1, shape

    return np.squeeze(result, axis=axes)


class IndexTester:
    """Expose `index_to_slices` through subscript syntax for a fixed shape."""

    def __init__(self, shape):
        # Shape passed to index_to_slices on every subscript.
        self.shape = shape

    def __getitem__(self, index):
        """Return `index_to_slices(index, self.shape)`."""
        normalised = index_to_slices(index, self.shape)
        return normalised


if __name__ == "__main__":
    # Print the normalised form of a few representative indices.
    tester = IndexTester((1000, 8, 10, 20000))

    samples = (
        0,
        (0, 1, 2, 3),
        slice(0, 10),
        Ellipsis,
        slice(None, -1),
    )
    for sample in samples:
        print(tester[sample])
Loading

0 comments on commit cad646a

Please sign in to comment.