Skip to content

Commit

Permalink
Merge pull request #99 from esheldon/isplit
Browse files Browse the repository at this point in the history
add isplit algorithm
  • Loading branch information
esheldon authored Oct 29, 2024
2 parents ba4ee98 + 120c5f6 commit a395bad
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
4 changes: 3 additions & 1 deletion RELEASE_NOTES
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

Enhancements

- add simple option to pbar to spam log files less
- add algorithms.isplit function, get indices to split
a sequence into chunks
- numpy_util.match works for non-integer data types
- add simple option to pbar to spam log files less

0.6.14
------
Expand Down
49 changes: 49 additions & 0 deletions esutil/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,55 @@
"""


def isplit(num, nchunks):
"""
Get indices to split a sequence into a number of chunks. This algorithm
produces nearly equal chunks when the data cannot be equally split.
Based on the algorithm for splitting arrays in numpy.array_split
Parameters
----------
num: int
Number of elements to be split into chunks
nchunks: int
Number of chunks
Returns
-------
subs: array
Array with fields 'start' and 'end for each chunks, so a chunk can
be gotten with
subs = isplit(arr.size, nchunks)
for i in range(nchunks):
achunk = array[subs['start'][i]:subs['end'][i]]
"""
import numpy as np

nchunks = int(nchunks)

if nchunks <= 0:
raise ValueError(f'got nchunks={nchunks} < 0')

neach_section, extras = divmod(num, nchunks)

section_sizes = (
[0] + extras * [neach_section+1]
+ (nchunks-extras) * [neach_section]
)
div_points = np.array(section_sizes, dtype=np.intp).cumsum()

subs = np.zeros(nchunks, dtype=[('start', 'i8'), ('end', 'i8')])

for i in range(nchunks):
subs['start'][i] = div_points[i]
subs['end'][i] = div_points[i + 1]

return subs


def quicksort(data):
"""
Name:
Expand Down
16 changes: 16 additions & 0 deletions esutil/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,19 @@ def test_quicksort():
s = data_orig.argsort()

assert np.all(data == data_orig[s])


def test_isplit():
num = 135
nchunks = 11

subs = eu.algorithm.isplit(num=num, nchunks=nchunks)
assert subs.size == 11

assert np.all(
subs['start'] == [0, 13, 26, 39, 51, 63, 75, 87, 99, 111, 123]
)

assert np.all(
subs['end'] == [13, 26, 39, 51, 63, 75, 87, 99, 111, 123, 135]
)

0 comments on commit a395bad

Please sign in to comment.