diff --git a/RELEASE_NOTES b/RELEASE_NOTES index d22f059..2819459 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -1,3 +1,10 @@ +0.6.15 (not yet released) +------ + +Enhancements + + - numpy_util.match works for non-integer data types + 0.6.14 ------ diff --git a/esutil/__init__.py b/esutil/__init__.py index bf7de63..b0287c8 100644 --- a/esutil/__init__.py +++ b/esutil/__init__.py @@ -83,7 +83,7 @@ class for gauss-legendre integration, which relies on the gauleg C++ extension. import sys -__version__ = "0.6.14" +__version__ = "0.6.15" def version(): return __version__ diff --git a/esutil/numpy_util.py b/esutil/numpy_util.py index 608a9cf..d19ebc8 100644 --- a/esutil/numpy_util.py +++ b/esutil/numpy_util.py @@ -1509,40 +1509,47 @@ def rem_dup(arr, flag, values=False): def match(arr1input, arr2input, presorted=False): """ - NAME: - match + Match two arrays, returning the indices of matches for each array, or + empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2] + is true for all corresponding pairs. For floating-point data this implies + exact matching with no floating-point tolerance. - CALLING SEQUENCE: - ind1,ind2 = match(arr1, arr2, presorted=False) + The data type can be int, float, string or bytes. - PURPOSE: - Match two numpy arrays. Return the indices of the matches or empty - arrays if no matches are found. This means arr1[ind1] == arr2[ind2] is - true for all corresponding pairs. arr1 must contain only unique - inputs, but arr2 may be non-unique. - If you know arr1 is sorted, set presorted=True and it will run - even faster + arr1 must contain only unique inputs, but arr2 may be non-unique. - METHOD: - uses searchsorted with some sugar. Much faster than old version - based on IDL code. + If you know arr1 is sorted, set presorted=True and it will run even faster - REVISION HISTORY: - Created 2015, Eli Rykoff, SLAC. + Parameters + ---------- + arr1: array + The first array, which must have unique elements. 
+ arr2: array + The second array. + presorted: bool, optional + If set to True, the first array is assumed to be sorted. + + Returns + ------- + ind1, ind2: array, array + The index arrays of matches for each array + + Revision history + ----------------- + Created 2015, Eli Rykoff, SLAC. """ # make sure 1D arr1 = np.atleast_1d(arr1input) arr2 = np.atleast_1d(arr2input) - # check for integer data... - if not issubclass(arr1.dtype.type, np.integer) or not issubclass( - arr2.dtype.type, np.integer - ): - mess = "Error: only works with integer types, got %s %s" - mess = mess % (arr1.dtype.type, arr2.dtype.type) - raise ValueError(mess) + # Classify by dtype rather than by indexing the raw input: arr1input[0] + # raises for scalar inputs and for empty inputs, before the size check + # below can report them. Object arrays are included because they may + # hold str or bytes; when is_string is set the high end clamp always + # runs, so max() is never evaluated on string data. + is_string = arr1.dtype.kind in ("S", "U", "O") if (arr1.size == 0) or (arr2.size == 0): mess = "Error: arr1 and arr2 must each be non-zero length" @@ -1563,7 +1570,7 @@ def match(arr1input, arr2input, presorted=False): sub1 = np.searchsorted(arr1, arr2, sorter=st1) # check for out-of-bounds at the high end if necessary - if arr2.max() > arr1.max(): + if is_string or arr2.max() > arr1.max(): (bad,) = np.where(sub1 == arr1.size) sub1[bad] = arr1.size - 1 diff --git a/esutil/tests/test_numpy_util.py b/esutil/tests/test_numpy_util.py index 90cebad..9f2fa60 100644 --- a/esutil/tests/test_numpy_util.py +++ b/esutil/tests/test_numpy_util.py @@ -58,3 +58,59 @@ def test_split_array(): assert np.all(chunks[6] == [18, 19, 20]) assert np.all(chunks[7] == [21, 22, 23]) assert np.all(chunks[8] == [24]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_int(presorted): + a1 = np.array([3, 10, 8, 4, 7]) + a2 = np.array([8, 3]) + + if not presorted: + ind = np.array([4, 1, 0, 2, 3]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + assert np.all(m1 == [3, 2]) + else: + m1, m2 = eu.numpy_util.match(np.sort(a1), a2, presorted=True) + assert np.all(m1 == [3, 0]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_float(presorted): + a1 = np.array([1.25, 6.61, 8.51, 
9.91, 11.25]) + a2 = np.array([6.61, 9.91]) + + if not presorted: + ind = np.array([4, 1, 0, 2, 3]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + assert np.all(m1 == [1, 4]) + else: + m1, m2 = eu.numpy_util.match(a1, a2, presorted=True) + assert np.all(m1 == [1, 3]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_str(presorted): + a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things']) + a2 = np.array(['goodbye', 'things', 'zz']) + + if not presorted: + ind = np.array([3, 4, 0, 2, 1]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + assert np.all(m1 == [4, 1]) + else: + m1, m2 = eu.numpy_util.match(a1, a2, presorted=True) + assert np.all(m1 == [1, 4]) + + +@pytest.mark.parametrize('presorted', [True, False]) +def test_match_nomatch(presorted): + a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things']) + a2 = np.array(['zz', 'bb']) + + if not presorted: + ind = np.array([3, 4, 0, 2, 1]) + m1, m2 = eu.numpy_util.match(a1[ind], a2) + else: + m1, m2 = eu.numpy_util.match(a1, a2, presorted=True) + + assert m1.size == 0 and m2.size == 0 diff --git a/setup.py b/setup.py index 1e5fe6d..c5dae01 100644 --- a/setup.py +++ b/setup.py @@ -244,7 +244,7 @@ def build_extensions(self): setup( name="esutil", - version="0.6.14", + version="0.6.15", author="Erin Scott Sheldon", author_email="erin.sheldon@gmail.com", classifiers=classifiers,