3
3
from .adapters import Adapter
4
4
import numpy as np
5
5
import dask .array as da
6
+ from functools import reduce
7
+ from dask .array .optimization import fuse_slice
6
8
7
9
from typing import Optional , Iterable , Any , Union
8
10
@@ -295,7 +297,6 @@ def __setitem__(self, key, value: np.ndarray):
295
297
296
298
region_slices = self .__slices (roi )
297
299
298
-
299
300
da .store (
300
301
self .data [roi_slices ], self ._source_data , regions = region_slices
301
302
)
@@ -306,7 +307,7 @@ def __setitem__(self, key, value: np.ndarray):
306
307
adapter for adapter in self .adapters if self ._is_slice (adapter )
307
308
]
308
309
309
- region_slices = self . _combine_slices ( * adapter_slices , key )
310
+ region_slices = reduce ( fuse_slice , [ * adapter_slices , key ] )
310
311
311
312
da .store (self .data [key ], self ._source_data , regions = region_slices )
312
313
@@ -352,63 +353,6 @@ def to_ndarray(self, roi, fill_value=0):
352
353
353
354
return data
354
355
355
- def _combine_slices (
356
- self , * roi_slices : list [Union [tuple [slice ], slice ]]
357
- ) -> list [slice ]:
358
- """Combine slices into a single slice."""
359
- # if there are multiple slices, then we are using adapters
360
- # this is important because if we are considering the adapter slices
361
- # we need to use the shape of the source data, not the adapted data
362
- use_adapters = len (roi_slices ) > 1
363
- roi_slices = [
364
- roi_slice if isinstance (roi_slice , tuple ) else (roi_slice ,)
365
- for roi_slice in roi_slices
366
- ]
367
- num_dims = max ([len (roi_slice ) for roi_slice in roi_slices ])
368
-
369
- remaining_dims = list (range (num_dims ))
370
- combined_ranges = [
371
- (
372
- range (0 , self .shape [d ], 1 )
373
- if not use_adapters
374
- else range (0 , self ._source_data .shape [d ], 1 )
375
- )
376
- for d in range (num_dims )
377
- ]
378
- combined_slices = []
379
-
380
- for roi_slice in roi_slices :
381
- dim_slices = [roi_slice [d ] for d in range (num_dims ) if len (roi_slice ) > d ]
382
-
383
- del_dims = []
384
- for d , s in enumerate (dim_slices ):
385
- current_dimension = remaining_dims [d ]
386
- combined_ranges [current_dimension ] = combined_ranges [current_dimension ][
387
- s
388
- ]
389
- if isinstance (s , int ):
390
- del_dims .append (d )
391
- for d in del_dims :
392
- del remaining_dims [d ]
393
-
394
- for combined_range in combined_ranges :
395
- if isinstance (combined_range , int ):
396
- combined_slices .append (combined_range )
397
- elif len (combined_range ) == 0 :
398
- combined_slices .append (slice (0 ))
399
- elif combined_range .stop < 0 :
400
- combined_slices .append (
401
- slice (combined_range .start , None , combined_range .step )
402
- )
403
- else :
404
- combined_slices .append (
405
- slice (
406
- combined_range .start , combined_range .stop , combined_range .step
407
- )
408
- )
409
-
410
- return tuple (combined_slices )
411
-
412
356
def __slices (self , roi , use_adapters : bool = True , check_chunk_align : bool = False ):
413
357
"""Get the voxel slices for the given roi."""
414
358
@@ -437,7 +381,7 @@ def __slices(self, roi, use_adapters: bool = True, check_chunk_align: bool = Fal
437
381
else []
438
382
)
439
383
440
- combined_slice = self . _combine_slices ( * adapter_slices , roi_slices )
384
+ combined_slice = reduce ( fuse_slice , [ * adapter_slices , roi_slices ] )
441
385
442
386
return combined_slice
443
387
@@ -448,9 +392,9 @@ def _is_slice(self, adapter: Adapter):
448
392
or isinstance (adapter , list )
449
393
):
450
394
return True
451
- elif isinstance (adapter , tuple ) and all (
452
- [ isinstance ( a , slice ) or isinstance ( a , int ) for a in adapter ]
453
- ) :
395
+ elif isinstance (adapter , tuple ) and all ([ self . _is_slice ( a ) for a in adapter ]):
396
+ return True
397
+ elif isinstance ( adapter , np . ndarray ) and adapter . dtype == bool :
454
398
return True
455
399
return False
456
400
0 commit comments