diff --git a/RELEASE_NOTES b/RELEASE_NOTES index bc9e384..8a27ba4 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -3,8 +3,10 @@ Enhancements - - add simple option to pbar to spam log files less + - add algorithms.isplit function, get indices to split + a sequence into chunks - numpy_util.match works for non-integer data types + - add simple option to pbar to spam log files less 0.6.14 ------ diff --git a/esutil/algorithm.py b/esutil/algorithm.py index ca471d3..9514210 100644 --- a/esutil/algorithm.py +++ b/esutil/algorithm.py @@ -21,6 +21,55 @@ """ +def isplit(num, nchunks): + """ + Get indices to split a sequence into a number of chunks. This algorithm + produces nearly equal chunks when the data cannot be equally split. + + Based on the algorithm for splitting arrays in numpy.array_split + + Parameters + ---------- + num: int + Number of elements to be split into chunks + nchunks: int + Number of chunks + + Returns + ------- + subs: array + Array with fields 'start' and 'end for each chunks, so a chunk can + be gotten with + + subs = isplit(arr.size, nchunks) + for i in range(nchunks): + achunk = array[subs['start'][i]:subs['end'][i]] + + """ + import numpy as np + + nchunks = int(nchunks) + + if nchunks <= 0: + raise ValueError(f'got nchunks={nchunks} < 0') + + neach_section, extras = divmod(num, nchunks) + + section_sizes = ( + [0] + extras * [neach_section+1] + + (nchunks-extras) * [neach_section] + ) + div_points = np.array(section_sizes, dtype=np.intp).cumsum() + + subs = np.zeros(nchunks, dtype=[('start', 'i8'), ('end', 'i8')]) + + for i in range(nchunks): + subs['start'][i] = div_points[i] + subs['end'][i] = div_points[i + 1] + + return subs + + def quicksort(data): """ Name: diff --git a/esutil/tests/test_algorithms.py b/esutil/tests/test_algorithms.py index 9e54234..316e02a 100644 --- a/esutil/tests/test_algorithms.py +++ b/esutil/tests/test_algorithms.py @@ -13,3 +13,19 @@ def test_quicksort(): s = data_orig.argsort() assert np.all(data == data_orig[s]) + + +def test_isplit(): + num = 135 + nchunks = 11 + + subs = eu.algorithm.isplit(num=num, nchunks=nchunks) + assert subs.size == 11 + + assert np.all( + subs['start'] == [0, 13, 26, 39, 51, 63, 75, 87, 99, 111, 123] + ) + + assert np.all( + subs['end'] == [13, 26, 39, 51, 63, 75, 87, 99, 111, 123, 135] + )