Skip to content

Commit

Permalink
improve slicing support
Browse files Browse the repository at this point in the history
allow integer indexing to eliminate dimensions
  • Loading branch information
pattonw committed Jun 21, 2024
1 parent 43e84f7 commit dc220c4
Showing 1 changed file with 171 additions and 52 deletions.
223 changes: 171 additions & 52 deletions funlib/persistence/arrays/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ class Array(Freezable):
The size of a chunk of the underlying data container in voxels.
adapters (``Optional[Union[Adapter, list[Adapter]]]``):
adapter (``Optional[Adapter]``):
The adapter or list of adapters to use for this array.
The adapter to use for this array. If you would like apply multiple
adapters, please look into either the `.adapt` method or the
`SequentialAdapter` class.
"""

Expand All @@ -52,7 +54,7 @@ class Array(Freezable):
axis_names: list[str]
units: list[str]
chunk_shape: Coordinate
adapter: list[Adapter]
adapter: Adapter

def __init__(
self,
Expand All @@ -65,37 +67,91 @@ def __init__(
adapter: Optional[Union[Adapter, Iterable[Adapter]]] = None,
):
self.data = da.from_array(data)
self._uncollapsed_dims = [True for _ in self.data.shape]
self.voxel_size = (
Coordinate(voxel_size) if voxel_size is not None else (1,) * len(data.shape)
voxel_size if voxel_size is not None else (1,) * len(data.shape)
)
self.offset = (
Coordinate(offset) if offset is not None else (0,) * len(data.shape)
)
# assign default axis names, if not given
if axis_names is None:
channel_names = [f"c{i}^" for i in range(self.channel_dims)]
spatial_names = [f"d{i}" for i in range(self.spatial_dims)]
axis_names = channel_names + spatial_names
self.axis_names = tuple(axis_names)
# assign unknown unit to each spatial dim, if not given
self.units = (
tuple(units) if units is not None else ("",) * self.spatial_dims
self.offset = offset if offset is not None else (0,) * len(data.shape)
self.axis_names = (
axis_names
if axis_names is not None
else tuple(f"c{i}^" for i in range(self.channel_dims))
+ tuple(f"d{i}" for i in range(self.voxel_size.dims))
)
self.units = units if units is not None else ("",) * self.voxel_size.dims
self.chunk_shape = Coordinate(chunk_shape) if chunk_shape is not None else None
self._source_data = data

adapter = [] if adapter is None else adapter
adapter = [adapter] if not isinstance(adapter, list) else adapter
self.adapter = adapter
if adapter is not None:
self.apply_adapter(adapter)

for adapter in self.adapter:
if not isinstance(adapter, slice):
self.data = adapter(self.data)
adapters = [] if adapter is None else [adapter]
self.adapters = adapters

self.freeze()

self.validate()

def uncollapsed_dims(self, physical: bool = False) -> list[bool]:
if physical:
return self._uncollapsed_dims[-self._voxel_size.dims :]
else:
return self._uncollapsed_dims

@property
def offset(self) -> Coordinate:
"""Get the offset of this array in world units."""
return Coordinate(
[
self._offset[ii]
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=True))
if uncollapsed
]
)

@offset.setter
def offset(self, offset: Iterable[int]) -> None:
self._offset = Coordinate(offset)

@property
def voxel_size(self) -> Coordinate:
"""Get the size of a voxel in world units."""
return Coordinate(
[
self._voxel_size[ii]
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=True))
if uncollapsed
]
)

@voxel_size.setter
def voxel_size(self, voxel_size: Iterable[int]) -> None:
self._voxel_size = Coordinate(voxel_size)

@property
def units(self) -> list[str]:
return [
self._units[ii]
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=True))
if uncollapsed
]

@units.setter
def units(self, units: list[str]) -> None:
self._units = list(units)

@property
def axis_names(self) -> list[str]:
return [
self._axis_names[ii]
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=False))
if uncollapsed
]

@axis_names.setter
def axis_names(self, axis_names):
self._axis_names = list(axis_names)

@property
def roi(self):
"""
Expand All @@ -109,7 +165,7 @@ def roi(self):

@property
def dims(self):
return len(self.shape)
return sum(self.uncollapsed_dims())

@property
def channel_dims(self):
Expand All @@ -136,17 +192,38 @@ def dtype(self):

@property
def is_writeable(self):
return len(self.adapter) == 0 or all(
[
isinstance(adapter, slice)
or (
isinstance(adapter, Iterable)
and all([isinstance(a, slice) for a in adapter])
)
for adapter in self.adapter
]
return len(self.adapters) == 0 or all(
[self._is_slice(adapter) for adapter in self.adapters]
)

def apply_adapter(self, adapter: Adapter):
if self._is_slice(adapter):
if not isinstance(adapter, tuple):
adapter = (adapter,)
for ii, a in enumerate(adapter):
if isinstance(a, int):
self._uncollapsed_dims[ii] = False
self.data = self.data[adapter]
elif callable(adapter):
self.data = adapter(self.data)
else:
raise Exception(
f"Adapter {adapter} is not a supported adapter. "
f"Supported adapters are: {Adapter}"
)

def adapt(self, adapter: Adapter):
"""Apply an adapter to this array.
Args:
adapter (``Adapter``):
The adapter to apply to this array.
"""
self.apply_adapter(adapter)
self.adapters.append(adapter)

def __getitem__(self, key) -> np.ndarray:
"""Get a sub-array or a single value.
Expand Down Expand Up @@ -175,7 +252,7 @@ def __getitem__(self, key) -> np.ndarray:
% (roi, self.roi)
)

return self.data[self.__slices(roi)].compute()
return self.data[self.__slices(roi, use_adapters=False)].compute()

elif isinstance(key, Coordinate):
coordinate = key
Expand Down Expand Up @@ -212,11 +289,12 @@ def __setitem__(self, key, value: np.ndarray):
% (roi, self.roi)
)

roi_slices = self.__slices(roi)
roi_slices = self.__slices(roi, use_adapters=False)
region_slices = self.__slices(roi)

self.data[roi_slices] = value

da.store(self.data[roi_slices], self._source_data, regions=roi_slices)
da.store(self.data[roi_slices], self._source_data, regions=region_slices)
else:
raise RuntimeError(
"This array is not writeable since you have applied a custom callable "
Expand Down Expand Up @@ -260,36 +338,66 @@ def to_ndarray(self, roi, fill_value=0):
return data

def _combine_slices(
self, *roi_slices: list[Union[list[slice], slice]]
self, *roi_slices: list[Union[tuple[slice], slice]]
) -> list[slice]:
"""Combine slices into a single slice."""
# if there are multiple slices, then we are using adapters
# this is important because if we are considering the adapter slices
# we need to use the shape of the source data, not the adapted data
use_adapters = len(roi_slices) > 1
roi_slices = [
roi_slice if isinstance(roi_slice, tuple) else (roi_slice,)
for roi_slice in roi_slices
]
num_dims = max([len(roi_slice) for roi_slice in roi_slices])

remaining_dims = list(range(num_dims))
combined_ranges = [
(
range(0, self.shape[d], 1)
if not use_adapters
else range(0, self._source_data.shape[d], 1)
)
for d in range(num_dims)
]
combined_slices = []
for d in range(num_dims):

for roi_slice in roi_slices:
dim_slices = [
roi_slice[d] if len(roi_slice) > d else slice(None)
for roi_slice in roi_slices
for d in range(num_dims)
]

slice_range = range(0, self.shape[d], 1)
for s in dim_slices:
slice_range = slice_range[s]
if len(slice_range) == 0:
return slice(0)
elif slice_range.stop < 0:
return slice(slice_range.start, None, slice_range.step)
combined_slices.append(
slice(slice_range.start, slice_range.stop, slice_range.step)
)
del_dims = []
for d, s in enumerate(dim_slices):
current_dimension = remaining_dims[d]
combined_ranges[current_dimension] = combined_ranges[current_dimension][
s
]
if isinstance(s, int):
del_dims.append(d)
for d in del_dims:
del remaining_dims[d]

for combined_range in combined_ranges:
if isinstance(combined_range, int):
combined_slices.append(combined_range)
elif len(combined_range) == 0:
combined_slices.append(slice(0))
elif combined_range.stop < 0:
combined_slices.append(
slice(combined_range.start, None, combined_range.step)
)
else:
combined_slices.append(
slice(
combined_range.start, combined_range.stop, combined_range.step
)
)

return tuple(combined_slices)

def __slices(self, roi, check_chunk_align=False):
def __slices(self, roi, use_adapters: bool = True, check_chunk_align: bool = False):
"""Get the voxel slices for the given roi."""

voxel_roi = (roi - self.offset) / self.voxel_size
Expand All @@ -311,14 +419,25 @@ def __slices(self, roi, check_chunk_align=False):

roi_slices = (slice(None),) * self.channel_dims + voxel_roi.to_slices()

adapter_slices = [
adapter for adapter in self.adapter if isinstance(adapter, slice)
]
adapter_slices = (
[adapter for adapter in self.adapters if self._is_slice(adapter)]
if use_adapters
else []
)

combined_slice = self._combine_slices(roi_slices, *adapter_slices)

return combined_slice

def _is_slice(self, adapter: Adapter):
if isinstance(adapter, slice) or isinstance(adapter, int):
return True
elif isinstance(adapter, tuple) and all(
[isinstance(a, slice) or isinstance(a, int) for a in adapter]
):
return True
return False

def __index(self, coordinate):
"""Get the voxel slices for the given coordinate."""

Expand Down

0 comments on commit dc220c4

Please sign in to comment.