Skip to content

Commit

Permalink
Merge pull request #95 from esheldon/sstr
Browse files Browse the repository at this point in the history
make numpy_util.match work for non-integer inputs
  • Loading branch information
esheldon authored Aug 20, 2024
2 parents 415974a + 6273689 commit e2eca31
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 26 deletions.
7 changes: 7 additions & 0 deletions RELEASE_NOTES
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
0.6.15 (not yet released)
------

Enhancements

- numpy_util.match works for non-integer data types

0.6.14
------

Expand Down
2 changes: 1 addition & 1 deletion esutil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class for gauss-legendre integration, which relies on the gauleg C++ extension.

import sys

__version__ = "0.6.14"
__version__ = "0.6.15"

def version():
return __version__
Expand Down
55 changes: 31 additions & 24 deletions esutil/numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,40 +1509,47 @@ def rem_dup(arr, flag, values=False):

def match(arr1input, arr2input, presorted=False):
"""
NAME:
match
Match two arrays, returning the indicies of matches for each array, or
empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2]
is true for all corresponding pairs. For floating-point data this implies
exact matching with no floating-point tolerance.
CALLING SEQUENCE:
ind1,ind2 = match(arr1, arr2, presorted=False)
The data type can be int, float, string or bytes.
PURPOSE:
Match two numpy arrays. Return the indices of the matches or empty
arrays if no matches are found. This means arr1[ind1] == arr2[ind2] is
true for all corresponding pairs. arr1 must contain only unique
inputs, but arr2 may be non-unique.
If you know arr1 is sorted, set presorted=True and it will run
even faster
arr1 must contain only unique inputs, but arr2 may be non-unique.
METHOD:
uses searchsorted with some sugar. Much faster than old version
based on IDL code.
If you know arr1 is sorted, set presorted=True and it will run even faster
REVISION HISTORY:
Created 2015, Eli Rykoff, SLAC.
Parameters
----------
arr1: array
The first array, which must have unique elements.
arr2: array
The second array.
presorted: bool, optional
If set to True, the first array is assumed to be sorted.
Returns
-------
ind1, ind2: array, array
The index arrays of matches for each array
Revision history
-----------------
Created 2015, Eli Rykoff, SLAC.
"""

# make sure 1D
arr1 = np.atleast_1d(arr1input)
arr2 = np.atleast_1d(arr2input)

# check for integer data...
if not issubclass(arr1.dtype.type, np.integer) or not issubclass(
arr2.dtype.type, np.integer
):
mess = "Error: only works with integer types, got %s %s"
mess = mess % (arr1.dtype.type, arr2.dtype.type)
raise ValueError(mess)
el = arr1input[0]

if isinstance(el, str) or isinstance(el, bytes):
is_string = True
else:
is_string = False

if (arr1.size == 0) or (arr2.size == 0):
mess = "Error: arr1 and arr2 must each be non-zero length"
Expand All @@ -1563,7 +1570,7 @@ def match(arr1input, arr2input, presorted=False):
sub1 = np.searchsorted(arr1, arr2, sorter=st1)

# check for out-of-bounds at the high end if necessary
if arr2.max() > arr1.max():
if is_string or arr2.max() > arr1.max():
(bad,) = np.where(sub1 == arr1.size)
sub1[bad] = arr1.size - 1

Expand Down
56 changes: 56 additions & 0 deletions esutil/tests/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,59 @@ def test_split_array():
assert np.all(chunks[6] == [18, 19, 20])
assert np.all(chunks[7] == [21, 22, 23])
assert np.all(chunks[8] == [24])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_int(presorted):
a1 = np.array([3, 10, 8, 4, 7])
a2 = np.array([8, 3])

if not presorted:
ind = np.array([4, 1, 0, 2, 3])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
assert np.all(m1 == [3, 2])
else:
m1, m2 = eu.numpy_util.match(a1, a2)
assert np.all(m1 == [2, 0])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_float(presorted):
a1 = np.array([1.25, 6.61, 8.51, 9.91, 11.25])
a2 = np.array([6.61, 9.91])

if not presorted:
ind = np.array([4, 1, 0, 2, 3])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
assert np.all(m1 == [1, 4])
else:
m1, m2 = eu.numpy_util.match(a1, a2)
assert np.all(m1 == [1, 3])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_str(presorted):
a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things'])
a2 = np.array(['goodbye', 'things', 'zz'])

if not presorted:
ind = np.array([3, 4, 0, 2, 1])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
assert np.all(m1 == [4, 1])
else:
m1, m2 = eu.numpy_util.match(a1, a2)
assert np.all(m1 == [1, 4])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_nomatch(presorted):
a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things'])
a2 = np.array(['zz', 'bb'])

if not presorted:
ind = np.array([3, 4, 0, 2, 1])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
else:
m1, m2 = eu.numpy_util.match(a1, a2)

assert m1.size == 0 and m2.size == 0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def build_extensions(self):

setup(
name="esutil",
version="0.6.14",
version="0.6.15",
author="Erin Scott Sheldon",
author_email="erin.sheldon@gmail.com",
classifiers=classifiers,
Expand Down

0 comments on commit e2eca31

Please sign in to comment.