Skip to content

Commit c25ff11

Browse files
committed
Add function to chain (simple) slice expressions
1 parent ccb4b81 commit c25ff11

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

funlib/persistence/arrays/slices.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
3+
4+
def chain_slices(slices_a, slices_b):
5+
6+
# make sure both slice expressions are tuples
7+
if not isinstance(slices_a, tuple):
8+
slices_a = (slices_a,)
9+
if not isinstance(slices_b, tuple):
10+
slices_b = (slices_b,)
11+
12+
# dimension of a is number of non-int expressions
13+
dim_a = sum([not isinstance(x, int) for x in slices_a])
14+
15+
# slices_b can't slice more dimensions than a has
16+
assert (
17+
len(slices_b) <= dim_a
18+
), f"Slice expression {slices_b} has too many dimensions to chain with {slices_a}"
19+
20+
chained = []
21+
22+
j = 0
23+
for slice_a in slices_a:
24+
25+
# if slice_a is int that dimension does not exist any longer, skip
26+
# also skip if b has no more elements
27+
if j == len(slices_b) or isinstance(slice_a, int):
28+
chained.append(slice_a)
29+
else:
30+
slice_b = slices_b[j]
31+
chained.append(_chain_slice(slice_a, slice_b))
32+
j += 1
33+
34+
return tuple(chained)
35+
36+
37+
def _chain_slice(a, b):
38+
39+
# a is a slice(start, stop, step) expression
40+
if isinstance(a, slice):
41+
42+
start_a = a.start if a.start else 0
43+
step_a = a.step if a.step else 1
44+
45+
if isinstance(b, int):
46+
47+
idx = start_a + step_a * b
48+
assert not a.stop or idx < a.stop, f"Slice {b} out of range for {b}"
49+
return idx
50+
51+
elif isinstance(b, slice):
52+
53+
start_b = b.start if b.start else 0
54+
step_b = b.step if b.step else 1
55+
56+
start = start_a + step_a * start_b if a.start or b.start else None
57+
stop = step_a * b.stop if b.stop else a.stop
58+
step = step_a * step_b if a.step or b.step else None
59+
60+
return slice(start, stop, step)
61+
62+
elif isinstance(b, list):
63+
64+
return list(_chain_slice(a, x) for x in b)
65+
66+
elif isinstance(b, np.ndarray):
67+
68+
# is b a mask array?
69+
if b.dtype == bool:
70+
raise RuntimeError("Not yet implemented")
71+
72+
return np.array([_chain_slice(a, x) for x in b])
73+
74+
else:
75+
76+
raise RuntimeError(
77+
f"Don't know how to deal with slice {b} of type {type(b)}"
78+
)
79+
80+
# is an index array
81+
elif isinstance(a, list):
82+
83+
return list(np.array(a)[(b,)])
84+
85+
elif isinstance(a, np.ndarray):
86+
87+
if a.dtype == bool:
88+
raise RuntimeError("Not yet implemented")
89+
90+
return a[(b,)]
91+
92+
else:
93+
94+
raise RuntimeError(f"Don't know how to deal with slice {a} of type {type(a)}")

tests/test_slices.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
from funlib.persistence.arrays.slices import chain_slices
3+
4+
5+
def test_slice_chaining():
6+
7+
base = np.s_[::2, 0, :4]
8+
9+
# chain with index expressions
10+
11+
s1 = chain_slices(base, np.s_[0])
12+
assert s1 == np.s_[0, 0, :4]
13+
14+
s2 = chain_slices(s1, np.s_[1])
15+
assert s2 == np.s_[0, 0, 1]
16+
17+
# chain with index arrays
18+
19+
s1 = chain_slices(base, np.s_[[0, 1, 1, 2, 3, 5], :])
20+
assert s1 == np.s_[[0, 2, 2, 4, 6, 10], 0, :4]
21+
22+
# ...and another index array
23+
s21 = chain_slices(s1, np.s_[[0, 3], :])
24+
assert s21 == np.s_[[0, 4], 0, :4]
25+
26+
# ...and a slice() expression
27+
s22 = chain_slices(s1, np.s_[1:4])
28+
assert s22 == np.s_[[2, 2, 4], 0, :4]
29+
30+
# chain with slice expressions
31+
32+
s1 = chain_slices(base, np.s_[10:20, ::2])
33+
assert s1 == np.s_[20:40:2, 0, :4:2]

0 commit comments

Comments
 (0)