Skip to content

Commit 34a79df

Browse files
committed
test and add support for channels last and its interaction with slicing operations
1 parent 06653e2 commit 34a79df

File tree

3 files changed

+170
-8
lines changed

3 files changed

+170
-8
lines changed

funlib/persistence/arrays/array.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,22 @@ def chunk_shape(self) -> Coordinate:
9797

9898
def uncollapsed_dims(self, physical: bool = False) -> list[bool]:
9999
if physical:
100-
return self._uncollapsed_dims[-self._metadata.voxel_size.dims :]
100+
return [
101+
x
102+
for x, c in zip(self._uncollapsed_dims, self._metadata.axis_names)
103+
if not c.endswith("^")
104+
]
101105
else:
102106
return self._uncollapsed_dims
103107

104108
@property
105109
def offset(self) -> Coordinate:
106110
"""Get the offset of this array in world units."""
111+
udims = self.uncollapsed_dims(physical=True)
107112
return Coordinate(
108113
[
109114
self._metadata.offset[ii]
110-
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=True))
115+
for ii, uncollapsed in enumerate(udims)
111116
if uncollapsed
112117
]
113118
)
@@ -119,10 +124,11 @@ def offset(self, offset: Iterable[int]) -> None:
119124
@property
120125
def voxel_size(self) -> Coordinate:
121126
"""Get the size of a voxel in world units."""
127+
udims = self.uncollapsed_dims(physical=True)
122128
return Coordinate(
123129
[
124130
self._metadata.voxel_size[ii]
125-
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=True))
131+
for ii, uncollapsed in enumerate(udims)
126132
if uncollapsed
127133
]
128134
)
@@ -133,9 +139,10 @@ def voxel_size(self, voxel_size: Iterable[int]) -> None:
133139

134140
@property
135141
def units(self) -> list[str]:
142+
udims = self.uncollapsed_dims(physical=True)
136143
return [
137144
self._metadata.units[ii]
138-
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=True))
145+
for ii, uncollapsed in enumerate(udims)
139146
if uncollapsed
140147
]
141148

@@ -145,6 +152,7 @@ def units(self, units: list[str]) -> None:
145152

146153
@property
147154
def axis_names(self) -> list[str]:
155+
print(self._metadata.axis_names, self._uncollapsed_dims)
148156
return [
149157
self._metadata.axis_names[ii]
150158
for ii, uncollapsed in enumerate(self.uncollapsed_dims(physical=False))
@@ -205,7 +213,12 @@ def apply_adapter(self, adapter: Adapter):
205213
adapter = (adapter,)
206214
for ii, a in enumerate(adapter):
207215
if isinstance(a, int):
208-
self._uncollapsed_dims[ii] = False
216+
for i, uc in enumerate(self._uncollapsed_dims):
217+
if uc:
218+
if ii == 0:
219+
self._uncollapsed_dims[i] = False
220+
break
221+
ii -= 1
209222
self.data = self.data[adapter]
210223
elif callable(adapter):
211224
self.data = adapter(self.data)

tests/test_array.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,80 @@ def test_slicing():
273273

274274
with pytest.raises(RuntimeError):
275275
a[:, :] = np.array([42, 43, 44]).reshape(3, 1)
276+
277+
278+
def test_slicing_channel_dim_last():
279+
a = Array(
280+
np.arange(0, 4 * 4).reshape(2, 2, 4),
281+
(0, 0),
282+
(1, 1),
283+
axis_names=["d0", "d1", "c0^"],
284+
)
285+
286+
a.adapt(np.s_[1, :, 0:3])
287+
assert a.shape == (2, 3)
288+
assert a.axis_names == ["d1", "c0^"], a.axis_names
289+
assert a.units == [""]
290+
291+
a.adapt(np.s_[:, 2])
292+
assert a.shape == (2,)
293+
assert a.axis_names == ["d1"]
294+
assert a.units == [""]
295+
296+
a[:] = 42
297+
298+
assert all([x == 42 for x in a._source_data[1, :, 2]]), a._source_data[1, :, 2]
299+
300+
# test with list indexing
301+
a = Array(
302+
np.arange(0, 4 * 4).reshape(2, 2, 4),
303+
(0, 0),
304+
(1, 1),
305+
axis_names=["d0", "d1", "c0^"],
306+
)
307+
308+
a.adapt(np.s_[[0, 1], 1, :])
309+
assert a.shape == (2, 4)
310+
assert a.axis_names == ["d0", "c0^"]
311+
assert a.units == [""]
312+
313+
a.adapt(np.s_[1, :])
314+
assert a.shape == (4,)
315+
assert a.axis_names == ["c0^"]
316+
assert a.units == []
317+
318+
a[:] = 42
319+
320+
assert all([x == 42 for x in a._source_data[1, 1, :]]), a._source_data[1, 1, :]
321+
322+
# test weird case
323+
a = Array(
324+
np.arange(0, 4 * 4).reshape(4, 2, 2),
325+
(0, 0),
326+
(1, 1),
327+
axis_names=["d0", "d1", "c0^"],
328+
)
329+
330+
a.adapt(np.s_[[2, 2, 2], 1, :])
331+
assert a.shape == (3, 2)
332+
assert a.axis_names == ["d0", "c0^"]
333+
assert a.units == [""]
334+
335+
a[:, :] = np.array([42, 43, 44]).reshape(3, 1)
336+
assert all([x == 44 for x in a._source_data[2, 1, :]]), a._source_data[2, 1, :]
337+
338+
# test_bool_indexing
339+
a = Array(
340+
np.arange(0, 4 * 4).reshape(2, 2, 4),
341+
(0, 0),
342+
(1, 1),
343+
axis_names=["d0", "d1", "c0^"],
344+
)
345+
346+
a.adapt(np.s_[1, :, np.array([True, True, True, False])])
347+
assert a.shape == (2, 3)
348+
assert a.axis_names == ["d1", "c0^"]
349+
assert a.units == [""]
350+
351+
with pytest.raises(RuntimeError):
352+
a[:, :] = np.array([42, 43, 44]).reshape(3, 1)

tests/test_datasets.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from funlib.persistence.arrays.metadata import MetaDataFormat
2-
from funlib.persistence.arrays.datasets import open_ds, prepare_ds
2+
from funlib.persistence.arrays.datasets import open_ds, prepare_ds, ArrayNotFoundError
33
from funlib.geometry import Coordinate, Roi
44

5-
from zarr.errors import ArrayNotFoundError
65
import numpy as np
76

87
import pytest
@@ -30,7 +29,7 @@ def test_helpers(tmpdir, store, dtype):
3029
"voxel_size": [1, 2, 3],
3130
"axis_names": ["sample^", "channel^", "z", "y", "x"],
3231
"units": ["nm", "nm", "nm"],
33-
}
32+
},
3433
)
3534

3635
# test prepare_ds fails if array does not exist and mode is read
@@ -184,3 +183,76 @@ def test_helpers(tmpdir, store, dtype):
184183
assert array.offset == metadata.offset
185184
assert array.axis_names == metadata.axis_names
186185
assert array.units == metadata.units
186+
187+
188+
@pytest.mark.parametrize("store", stores.keys())
189+
@pytest.mark.parametrize("dtype", [np.float32, np.uint8, np.uint64])
190+
def test_open_ds(tmpdir, store, dtype):
191+
shape = Coordinate(1, 1, 10, 20, 30)
192+
store = tmpdir / store
193+
metadata = MetaDataFormat().parse(
194+
shape,
195+
{
196+
"offset": [100, 200, 400],
197+
"voxel_size": [1, 2, 3],
198+
"axis_names": ("sample^", "channel^", "z", "y", "x"),
199+
"units": ("nm", "nm", "nm"),
200+
},
201+
)
202+
203+
# test open_ds fails if array does not exist and mode is read
204+
with pytest.raises(ArrayNotFoundError):
205+
open_ds(
206+
store,
207+
offset=metadata.offset,
208+
voxel_size=metadata.voxel_size,
209+
axis_names=metadata.axis_names,
210+
units=metadata.units,
211+
mode="r",
212+
)
213+
214+
# test open_ds creates array if it does not exist and mode is write
215+
array = prepare_ds(
216+
store,
217+
shape,
218+
offset=metadata.offset,
219+
voxel_size=metadata.voxel_size,
220+
axis_names=metadata.axis_names,
221+
units=metadata.units,
222+
dtype=dtype,
223+
mode="w",
224+
)
225+
assert array.roi == Roi(
226+
metadata.offset, metadata.voxel_size * Coordinate(*shape[-3:])
227+
)
228+
assert array.voxel_size == metadata.voxel_size
229+
assert array.offset == metadata.offset
230+
assert array.axis_names == metadata.axis_names
231+
assert array.units == metadata.units
232+
233+
# test open_ds opens array if it exists and mode is read
234+
array = open_ds(
235+
store,
236+
offset=metadata.offset,
237+
voxel_size=metadata.voxel_size,
238+
axis_names=metadata.axis_names,
239+
units=metadata.units,
240+
mode="r",
241+
)
242+
assert array.roi == Roi(
243+
metadata.offset, metadata.voxel_size * Coordinate(*shape[-3:])
244+
)
245+
assert array.voxel_size == metadata.voxel_size
246+
assert array.offset == metadata.offset
247+
assert array.axis_names == metadata.axis_names
248+
assert array.units == metadata.units
249+
250+
# test open_ds fails if array exists and is opened in read mode
251+
# with incompatible arguments
252+
array = open_ds(
253+
store,
254+
offset=(1, 2, 3),
255+
voxel_size=(1, 2, 3),
256+
axis_names=None,
257+
units=None,
258+
)

0 commit comments

Comments
 (0)