Skip to content

Commit fbde9a1

Browse files
committed
use dasks fuse_slice operation
1 parent 7797e1d commit fbde9a1

File tree

3 files changed

+29
-175
lines changed

3 files changed

+29
-175
lines changed

funlib/persistence/arrays/array.py

Lines changed: 7 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from .adapters import Adapter
44
import numpy as np
55
import dask.array as da
6+
from functools import reduce
7+
from dask.array.optimization import fuse_slice
68

79
from typing import Optional, Iterable, Any, Union
810

@@ -295,7 +297,6 @@ def __setitem__(self, key, value: np.ndarray):
295297

296298
region_slices = self.__slices(roi)
297299

298-
299300
da.store(
300301
self.data[roi_slices], self._source_data, regions=region_slices
301302
)
@@ -306,7 +307,7 @@ def __setitem__(self, key, value: np.ndarray):
306307
adapter for adapter in self.adapters if self._is_slice(adapter)
307308
]
308309

309-
region_slices = self._combine_slices(*adapter_slices, key)
310+
region_slices = reduce(fuse_slice, [*adapter_slices, key])
310311

311312
da.store(self.data[key], self._source_data, regions=region_slices)
312313

@@ -352,63 +353,6 @@ def to_ndarray(self, roi, fill_value=0):
352353

353354
return data
354355

355-
def _combine_slices(
356-
self, *roi_slices: list[Union[tuple[slice], slice]]
357-
) -> list[slice]:
358-
"""Combine slices into a single slice."""
359-
# if there are multiple slices, then we are using adapters
360-
# this is important because if we are considering the adapter slices
361-
# we need to use the shape of the source data, not the adapted data
362-
use_adapters = len(roi_slices) > 1
363-
roi_slices = [
364-
roi_slice if isinstance(roi_slice, tuple) else (roi_slice,)
365-
for roi_slice in roi_slices
366-
]
367-
num_dims = max([len(roi_slice) for roi_slice in roi_slices])
368-
369-
remaining_dims = list(range(num_dims))
370-
combined_ranges = [
371-
(
372-
range(0, self.shape[d], 1)
373-
if not use_adapters
374-
else range(0, self._source_data.shape[d], 1)
375-
)
376-
for d in range(num_dims)
377-
]
378-
combined_slices = []
379-
380-
for roi_slice in roi_slices:
381-
dim_slices = [roi_slice[d] for d in range(num_dims) if len(roi_slice) > d]
382-
383-
del_dims = []
384-
for d, s in enumerate(dim_slices):
385-
current_dimension = remaining_dims[d]
386-
combined_ranges[current_dimension] = combined_ranges[current_dimension][
387-
s
388-
]
389-
if isinstance(s, int):
390-
del_dims.append(d)
391-
for d in del_dims:
392-
del remaining_dims[d]
393-
394-
for combined_range in combined_ranges:
395-
if isinstance(combined_range, int):
396-
combined_slices.append(combined_range)
397-
elif len(combined_range) == 0:
398-
combined_slices.append(slice(0))
399-
elif combined_range.stop < 0:
400-
combined_slices.append(
401-
slice(combined_range.start, None, combined_range.step)
402-
)
403-
else:
404-
combined_slices.append(
405-
slice(
406-
combined_range.start, combined_range.stop, combined_range.step
407-
)
408-
)
409-
410-
return tuple(combined_slices)
411-
412356
def __slices(self, roi, use_adapters: bool = True, check_chunk_align: bool = False):
413357
"""Get the voxel slices for the given roi."""
414358

@@ -437,7 +381,7 @@ def __slices(self, roi, use_adapters: bool = True, check_chunk_align: bool = Fal
437381
else []
438382
)
439383

440-
combined_slice = self._combine_slices(*adapter_slices, roi_slices)
384+
combined_slice = reduce(fuse_slice, [*adapter_slices, roi_slices])
441385

442386
return combined_slice
443387

@@ -448,9 +392,9 @@ def _is_slice(self, adapter: Adapter):
448392
or isinstance(adapter, list)
449393
):
450394
return True
451-
elif isinstance(adapter, tuple) and all(
452-
[isinstance(a, slice) or isinstance(a, int) for a in adapter]
453-
):
395+
elif isinstance(adapter, tuple) and all([self._is_slice(a) for a in adapter]):
396+
return True
397+
elif isinstance(adapter, np.ndarray) and adapter.dtype == bool:
454398
return True
455399
return False
456400

funlib/persistence/arrays/slices.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

tests/test_slices.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,41 @@
11
import numpy as np
2-
from funlib.persistence.arrays.slices import chain_slices
2+
from dask.array.optimization import fuse_slice
3+
from functools import reduce
4+
import pytest
35

46

57
def test_slice_chaining():
68

7-
base = np.s_[::2, 0, :4]
9+
def combine_slices(*slices):
10+
return reduce(fuse_slice, slices)
11+
12+
base = np.s_[::2, :, :4]
813

914
# chain with index expressions
1015

11-
s1 = chain_slices(base, np.s_[0])
12-
assert s1 == np.s_[0, 0, :4]
16+
s1 = combine_slices(base, np.s_[0])
17+
assert s1 == np.s_[0, :, :4]
1318

14-
s2 = chain_slices(s1, np.s_[1])
15-
assert s2 == np.s_[0, 0, 1]
19+
s2 = combine_slices(s1, np.s_[1])
20+
assert s2 == np.s_[0, 1, :4]
1621

1722
# chain with index arrays
1823

19-
s1 = chain_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :])
20-
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0, :4]
24+
s1 = combine_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :])
25+
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0:, :4]
2126

2227
# ...and another index array
23-
s21 = chain_slices(s1, np.s_[[0, 3], :])
24-
assert s21 == np.s_[[0, 4], 0, :4]
28+
with pytest.raises(NotImplementedError):
29+
# this is not supported because the combined indexing
30+
# operation would not behave the same as the individual
31+
# indexing operations performed in sequence
32+
combine_slices(s1, np.s_[[0, 3], 2])
2533

2634
# ...and a slice() expression
27-
s22 = chain_slices(s1, np.s_[1:4])
28-
assert s22 == np.s_[[2, 2, 4], 0, :4]
35+
s22 = combine_slices(s1, np.s_[1:4])
36+
assert s22 == np.s_[[2, 2, 4], 0:, :4]
2937

3038
# chain with slice expressions
3139

32-
s1 = chain_slices(base, np.s_[10:20, ::2])
33-
assert s1 == np.s_[20:40:2, 0, :4:2]
40+
s1 = combine_slices(base, np.s_[10:20, ::2, 0])
41+
assert s1 == np.s_[20:40:2, 0::2, 0]

0 commit comments

Comments
 (0)