Skip to content

Commit 58f6f64

Browse files
committed
add isplit algorithm
Get indices to split a sequence into a number of chunks. This algorithm produces nearly equal chunks when the data cannot be equally split. Based on the algorithm for splitting arrays in numpy.array_split
1 parent ba4ee98 commit 58f6f64

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

esutil/algorithm.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,55 @@
2121
"""
2222

2323

24+
def isplit(num, nchunks):
    """
    Get indices to split a sequence into a number of chunks.  This algorithm
    produces nearly equal chunks when the data cannot be equally split.

    Based on the algorithm for splitting arrays in numpy.array_split

    Parameters
    ----------
    num: int
        Number of elements to be split into chunks.  Must be non-negative.
    nchunks: int
        Number of chunks.  Must be positive.

    Returns
    -------
    subs: array
        Array with fields 'start' and 'end' for each chunk, so a chunk can
        be gotten with

            subs = isplit(arr.size, nchunks)
            for i in range(nchunks):
                achunk = arr[subs['start'][i]:subs['end'][i]]

        The first ``num % nchunks`` chunks hold one more element than the
        rest.  When num < nchunks the trailing chunks are empty
        (start == end).

    Raises
    ------
    ValueError
        If nchunks is not positive or num is negative.
    """
    import numpy as np

    nchunks = int(nchunks)

    if nchunks <= 0:
        raise ValueError(f'got nchunks={nchunks} <= 0')

    # a negative count would produce negative, decreasing split points
    if num < 0:
        raise ValueError(f'got num={num} < 0')

    neach_section, extras = divmod(num, nchunks)

    # leading 0 so that the cumulative sum gives nchunks + 1 boundaries;
    # the first `extras` sections absorb the remainder, one element each
    section_sizes = (
        [0] + extras * [neach_section + 1]
        + (nchunks - extras) * [neach_section]
    )
    div_points = np.array(section_sizes, dtype=np.intp).cumsum()

    subs = np.zeros(nchunks, dtype=[('start', 'i8'), ('end', 'i8')])

    # consecutive boundaries delimit each chunk
    subs['start'] = div_points[:-1]
    subs['end'] = div_points[1:]

    return subs
71+
72+
2473
def quicksort(data):
2574
"""
2675
Name:

esutil/tests/test_algorithms.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,19 @@ def test_quicksort():
1313
s = data_orig.argsort()
1414

1515
assert np.all(data == data_orig[s])
16+
17+
18+
def test_isplit():
    """
    Check isplit on 135 elements in 11 chunks: the first three chunks
    hold 13 elements and the remaining eight hold 12.
    """
    num = 135
    nchunks = 11

    expected_start = [0, 13, 26, 39, 51, 63, 75, 87, 99, 111, 123]
    expected_end = [13, 26, 39, 51, 63, 75, 87, 99, 111, 123, 135]

    subs = eu.algorithm.isplit(num=num, nchunks=nchunks)

    assert subs.size == 11

    starts_match = subs['start'] == expected_start
    assert np.all(starts_match)

    ends_match = subs['end'] == expected_end
    assert np.all(ends_match)

0 commit comments

Comments
 (0)