Skip to content

Commit

Permalink
make numpy_util.match work for non-integer inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
esheldon committed Aug 20, 2024
1 parent 415974a commit 444cbf6
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 24 deletions.
52 changes: 28 additions & 24 deletions esutil/numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,40 +1509,44 @@ def rem_dup(arr, flag, values=False):

def match(arr1input, arr2input, presorted=False):
"""
NAME:
match
Match two arrays, returning the indicies of matches for each array, or
empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2]
is true for all corresponding pairs.
CALLING SEQUENCE:
ind1,ind2 = match(arr1, arr2, presorted=False)
arr1 must contain only unique inputs, but arr2 may be non-unique.
PURPOSE:
Match two numpy arrays. Return the indices of the matches or empty
arrays if no matches are found. This means arr1[ind1] == arr2[ind2] is
true for all corresponding pairs. arr1 must contain only unique
inputs, but arr2 may be non-unique.
If you know arr1 is sorted, set presorted=True and it will run
even faster
If you know arr1 is sorted, set presorted=True and it will run even faster
METHOD:
uses searchsorted with some sugar. Much faster than old version
based on IDL code.
REVISION HISTORY:
Created 2015, Eli Rykoff, SLAC.
Parameters
----------
arr1: array
The first array, which must have unique elements.
arr2: array
The second array.
presorted: bool, optional
If set to True, the first array is assumed to be sorted.
Returns
-------
ind1, ind2: array, array
The index arrays of matches for each array
Revision history
-----------------
Created 2015, Eli Rykoff, SLAC.
"""

# make sure 1D
arr1 = np.atleast_1d(arr1input)
arr2 = np.atleast_1d(arr2input)

# check for integer data...
if not issubclass(arr1.dtype.type, np.integer) or not issubclass(
arr2.dtype.type, np.integer
):
mess = "Error: only works with integer types, got %s %s"
mess = mess % (arr1.dtype.type, arr2.dtype.type)
raise ValueError(mess)
el = arr1input[0]

if isinstance(el, str) or isinstance(el, bytes):
is_string = True
else:
is_string = False

if (arr1.size == 0) or (arr2.size == 0):
mess = "Error: arr1 and arr2 must each be non-zero length"
Expand All @@ -1563,7 +1567,7 @@ def match(arr1input, arr2input, presorted=False):
sub1 = np.searchsorted(arr1, arr2, sorter=st1)

# check for out-of-bounds at the high end if necessary
if arr2.max() > arr1.max():
if is_string or arr2.max() > arr1.max():
(bad,) = np.where(sub1 == arr1.size)
sub1[bad] = arr1.size - 1

Expand Down
56 changes: 56 additions & 0 deletions esutil/tests/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,59 @@ def test_split_array():
assert np.all(chunks[6] == [18, 19, 20])
assert np.all(chunks[7] == [21, 22, 23])
assert np.all(chunks[8] == [24])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_int(presorted):
a1 = np.array([3, 10, 8, 4, 7])
a2 = np.array([8, 3])

if not presorted:
ind = np.array([4, 1, 0, 2, 3])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
assert np.all(m1 == [3, 2])
else:
m1, m2 = eu.numpy_util.match(a1, a2)
assert np.all(m1 == [2, 0])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_float(presorted):
a1 = np.array([1.25, 6.61, 8.51, 9.91, 11.25])
a2 = np.array([6.61, 9.91])

if not presorted:
ind = np.array([4, 1, 0, 2, 3])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
assert np.all(m1 == [1, 4])
else:
m1, m2 = eu.numpy_util.match(a1, a2)
assert np.all(m1 == [1, 3])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_str(presorted):
a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things'])
a2 = np.array(['goodbye', 'things', 'zz'])

if not presorted:
ind = np.array([3, 4, 0, 2, 1])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
assert np.all(m1 == [4, 1])
else:
m1, m2 = eu.numpy_util.match(a1, a2)
assert np.all(m1 == [1, 4])


@pytest.mark.parametrize('presorted', [True, False])
def test_match_none(presorted):
a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things'])
a2 = np.array(['zz', 'bb'])

if not presorted:
ind = np.array([3, 4, 0, 2, 1])
m1, m2 = eu.numpy_util.match(a1[ind], a2)
else:
m1, m2 = eu.numpy_util.match(a1, a2)

assert m1.size == 0 and m2.size == 0

0 comments on commit 444cbf6

Please sign in to comment.