diff --git a/esutil/numpy_util.py b/esutil/numpy_util.py index 608a9cf..8717549 100644 --- a/esutil/numpy_util.py +++ b/esutil/numpy_util.py @@ -1509,40 +1509,44 @@ def rem_dup(arr, flag, values=False): def match(arr1input, arr2input, presorted=False): """ - NAME: - match + Match two arrays, returning the indicies of matches for each array, or + empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2] + is true for all corresponding pairs. - CALLING SEQUENCE: - ind1,ind2 = match(arr1, arr2, presorted=False) + arr1 must contain only unique inputs, but arr2 may be non-unique. - PURPOSE: - Match two numpy arrays. Return the indices of the matches or empty - arrays if no matches are found. This means arr1[ind1] == arr2[ind2] is - true for all corresponding pairs. arr1 must contain only unique - inputs, but arr2 may be non-unique. - If you know arr1 is sorted, set presorted=True and it will run - even faster + If you know arr1 is sorted, set presorted=True and it will run even faster - METHOD: - uses searchsorted with some sugar. Much faster than old version - based on IDL code. - REVISION HISTORY: - Created 2015, Eli Rykoff, SLAC. + Parameters + ---------- + arr1: array + The first array, which must have unique elements. + arr2: array + The second array. + presorted: bool, optional + If set to True, the first array is assumed to be sorted. + Returns + ------- + ind1, ind2: array, array + The index arrays of matches for each array + + Revision history + ----------------- + Created 2015, Eli Rykoff, SLAC. """ # make sure 1D arr1 = np.atleast_1d(arr1input) arr2 = np.atleast_1d(arr2input) - # check for integer data... - if not issubclass(arr1.dtype.type, np.integer) or not issubclass( - arr2.dtype.type, np.integer - ): - mess = "Error: only works with integer types, got %s %s" - mess = mess % (arr1.dtype.type, arr2.dtype.type) - raise ValueError(mess) + el = arr1input[0] + + if isinstance(el, str) or isinstance(el, bytes): + is_string = True + else: + is_string = False if (arr1.size == 0) or (arr2.size == 0): mess = "Error: arr1 and arr2 must each be non-zero length" @@ -1563,7 +1567,7 @@ def match(arr1input, arr2input, presorted=False): sub1 = np.searchsorted(arr1, arr2, sorter=st1) # check for out-of-bounds at the high end if necessary - if arr2.max() > arr1.max(): + if is_string or arr2.max() > arr1.max(): (bad,) = np.where(sub1 == arr1.size) sub1[bad] = arr1.size - 1 diff --git a/esutil/tests/test_numpy_util.py b/esutil/tests/test_numpy_util.py index 90cebad..ff4571d 100644 --- a/esutil/tests/test_numpy_util.py +++ b/esutil/tests/test_numpy_util.py @@ -58,3 +58,59 @@ def test_split_array(): assert np.all(chunks[6] == [18, 19, 20]) assert np.all(chunks[7] == [21, 22, 23]) assert np.all(chunks[8] == [24]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_int(presorted): + a1 = np.array([3, 10, 8, 4, 7]) + a2 = np.array([8, 3]) + + if not presorted: + ind = np.array([4, 1, 0, 2, 3]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + assert np.all(m1 == [3, 2]) + else: + m1, m2 = eu.numpy_util.match(a1, a2) + assert np.all(m1 == [2, 0]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_float(presorted): + a1 = np.array([1.25, 6.61, 8.51, 9.91, 11.25]) + a2 = np.array([6.61, 9.91]) + + if not presorted: + ind = np.array([4, 1, 0, 2, 3]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + assert np.all(m1 == [1, 4]) + else: + m1, m2 = eu.numpy_util.match(a1, a2) + assert np.all(m1 == [1, 3]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_str(presorted): + a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things']) + a2 = np.array(['goodbye', 'things', 'zz']) + + if not presorted: + ind = np.array([3, 4, 0, 2, 1]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + assert np.all(m1 == [4, 1]) + else: + m1, m2 = eu.numpy_util.match(a1, a2) + assert np.all(m1 == [1, 4]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_none(presorted): + a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things']) + a2 = np.array(['zz', 'bb']) + + if not presorted: + ind = np.array([3, 4, 0, 2, 1]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + else: + m1, m2 = eu.numpy_util.match(a1, a2) + + assert m1.size == 0 and m2.size == 0