make numpy_util.match work for non-integer inputs #95

Merged: 6 commits, Aug 20, 2024
7 changes: 7 additions & 0 deletions RELEASE_NOTES
@@ -1,3 +1,10 @@
+0.6.15 (not yet released)
+------
+
+Enhancements
+
+- numpy_util.match works for non-integer data types
+
 0.6.14
 ------

2 changes: 1 addition & 1 deletion esutil/__init__.py
@@ -83,7 +83,7 @@ class for gauss-legendre integration, which relies on the gauleg C++ extension.

 import sys

-__version__ = "0.6.14"
+__version__ = "0.6.15"

 def version():
     return __version__
55 changes: 31 additions & 24 deletions esutil/numpy_util.py
@@ -1509,40 +1509,47 @@ def rem_dup(arr, flag, values=False):

 def match(arr1input, arr2input, presorted=False):
     """
-    NAME:
-        match
+    Match two arrays, returning the indicies of matches for each array, or
+    empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2]
+    is true for all corresponding pairs. For floating-point data this implies
+    exact matching with no floating-point tolerance.

-    CALLING SEQUENCE:
-        ind1,ind2 = match(arr1, arr2, presorted=False)
+    The data type can be int, float, string or bytes.

-    PURPOSE:
-        Match two numpy arrays. Return the indices of the matches or empty
-        arrays if no matches are found. This means arr1[ind1] == arr2[ind2] is
-        true for all corresponding pairs. arr1 must contain only unique
-        inputs, but arr2 may be non-unique.
-        If you know arr1 is sorted, set presorted=True and it will run
-        even faster
+    arr1 must contain only unique inputs, but arr2 may be non-unique.

-    METHOD:
-        uses searchsorted with some sugar. Much faster than old version
-        based on IDL code.
+    If you know arr1 is sorted, set presorted=True and it will run even faster

-    REVISION HISTORY:
-        Created 2015, Eli Rykoff, SLAC.

+    Parameters
+    ----------
+    arr1: array
+        The first array, which must have unique elements.
+    arr2: array
+        The second array.
+    presorted: bool, optional
+        If set to True, the first array is assumed to be sorted.
+
+    Returns
+    -------
+    ind1, ind2: array, array
+        The index arrays of matches for each array
+
+    Revision history
+    -----------------
+    Created 2015, Eli Rykoff, SLAC.
     """

     # make sure 1D
     arr1 = np.atleast_1d(arr1input)
     arr2 = np.atleast_1d(arr2input)

-    # check for integer data...
-    if not issubclass(arr1.dtype.type, np.integer) or not issubclass(
-        arr2.dtype.type, np.integer
-    ):
-        mess = "Error: only works with integer types, got %s %s"
-        mess = mess % (arr1.dtype.type, arr2.dtype.type)
-        raise ValueError(mess)
+    el = arr1input[0]
+
+    if isinstance(el, str) or isinstance(el, bytes):
+        is_string = True
+    else:
+        is_string = False

     if (arr1.size == 0) or (arr2.size == 0):
         mess = "Error: arr1 and arr2 must each be non-zero length"
@@ -1563,7 +1570,7 @@ def match(arr1input, arr2input, presorted=False):
     sub1 = np.searchsorted(arr1, arr2, sorter=st1)

     # check for out-of-bounds at the high end if necessary
-    if arr2.max() > arr1.max():
+    if is_string or arr2.max() > arr1.max():
         (bad,) = np.where(sub1 == arr1.size)
         sub1[bad] = arr1.size - 1
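
For readers following the hunk above, here is a minimal, self-contained sketch of searchsorted-based matching that accepts any sortable dtype. It is an illustration under stated assumptions, not the esutil implementation: `match_sketch` is a hypothetical name, and unlike the diff it clips out-of-range indices unconditionally instead of first comparing `arr2.max()` with `arr1.max()`, so it never calls `max()` on the inputs.

```python
import numpy as np


def match_sketch(arr1, arr2, presorted=False):
    """Hypothetical stand-in for a searchsorted-based match (not esutil's code).

    arr1 must contain unique values; arr2 may contain repeats.
    Returns (ind1, ind2) such that arr1[ind1] == arr2[ind2] exactly.
    """
    arr1 = np.atleast_1d(arr1)
    arr2 = np.atleast_1d(arr2)
    if arr1.size == 0 or arr2.size == 0:
        raise ValueError("arr1 and arr2 must each be non-zero length")

    # Indirect sort of arr1, skipped when the caller promises it is sorted.
    order = None if presorted else np.argsort(arr1)

    # Insertion position of every arr2 element within sorted arr1.
    sub1 = np.searchsorted(arr1, arr2, sorter=order)

    # Values above everything in arr1 land at arr1.size, which is out of
    # bounds; clip them so the equality test below can index safely.
    sub1[sub1 == arr1.size] = arr1.size - 1

    # Map sorted positions back to positions in the original arr1.
    candidates = sub1 if presorted else order[sub1]

    # Keep only candidates that are exactly equal (no tolerance).
    (ind2,) = np.where(arr1[candidates] == arr2)
    ind1 = candidates[ind2]
    return ind1, ind2


# Unsorted integers.
ints_1, ints_2 = match_sketch(np.array([3, 10, 8, 4, 7]), np.array([8, 3]))

# Sorted strings, including a value ("zz") beyond the last element of arr1.
strs_1, strs_2 = match_sketch(
    np.array(["blah", "goodbye", "hello", "stuff", "things"]),
    np.array(["goodbye", "things", "zz"]),
    presorted=True,
)
```

The `"zz"` case is the one the clipping step exists for: it sorts after every element of the first array, so its insertion position would otherwise index past the end.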
Collaborator

I can't comment on the lines below in the PR because of GH. But for floating-point inputs I don't think we want (sub2,) = np.where(arr1[st1[sub1]] == arr2) or (sub2,) = np.where(arr1[sub1] == arr2). Instead, in these cases we need np.isclose() with some suitable defaults for rtol and atol and a way to override them. (The numpy defaults for rtol and atol seem appropriate for 32-bit floats and not for 64-bit doubles, if that is relevant as well.)

Owner Author

I did mean this to be an exact match test. It is not a common use case, but not unheard of: you have written the same data out to multiple binary files and the only way to match them is through some fields you expect to match exactly.

But maybe we could either add a keyword "close" or "inexact", or a separate function aimed at floating point.

Collaborator

I couldn't comment on those lines because they weren't close enough to your changes. It's always been a GH review problem.

Anyway, we either (a) make it super clear that this has to be an exact floating-point match, or (b) add a separate function or a keyword. Maybe (b) is a separate PR though, so if you just update the docstring now, that would be sufficient.

Owner Author

The doc says: "This means arr1[ind1] == arr2[ind2] is true for all corresponding pairs". Is that sufficient?

Collaborator

I think that floating-point data should be called out explicitly, e.g. "For floating-point data this implies exact matching with no floating-point tolerance."

Owner Author

updated
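
The distinction discussed in this thread can be illustrated with a short sketch. The first part contrasts exact equality with `np.isclose`, and shows the round-trip case where exact matching is the right test. The second part is a brute-force tolerant matcher of the kind suggested as a possible follow-up. Everything here is an assumption for illustration: `match_close_sketch` is a hypothetical name that does not exist in esutil, its `rtol` and `atol` defaults are arbitrary choices, and it uses O(n*m) memory.

```python
import numpy as np

# Classic rounding example: the two doubles differ in the last bit.
computed = 0.1 + 0.2
literal = 0.3
exact_equal = computed == literal                  # exact comparison
close_equal = bool(np.isclose(computed, literal))  # numpy defaults: rtol=1e-05, atol=1e-08

# The exact-match use case: the same doubles written out and read back
# are bit-for-bit identical, so == is the appropriate test.
original = np.array([1.25, 6.61, 8.51, 9.91, 11.25])
restored = np.frombuffer(original.tobytes(), dtype=original.dtype)
roundtrip_equal = bool(np.all(original == restored))


def match_close_sketch(arr1, arr2, rtol=1e-9, atol=0.0):
    """Hypothetical tolerant matcher; NOT part of esutil.

    Compares every element of arr1 with every element of arr2 using
    np.isclose and returns the index pairs that agree within tolerance.
    Pairs come back ordered by arr1 index.
    """
    arr1 = np.atleast_1d(arr1)
    arr2 = np.atleast_1d(arr2)
    close = np.isclose(arr1[:, None], arr2[None, :], rtol=rtol, atol=atol)
    ind1, ind2 = np.nonzero(close)
    return ind1, ind2


reference = np.array([2.0, 0.3])
recomputed = np.array([0.1 + 0.2, 2.0])
close_1, close_2 = match_close_sketch(reference, recomputed)
```

An exact matcher would pair only the `2.0` entries here, while the tolerant sketch also pairs `0.3` with `0.1 + 0.2`. Which behavior is wanted depends on whether the values were copied or recomputed, which is the point both reviewers converge on.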


56 changes: 56 additions & 0 deletions esutil/tests/test_numpy_util.py
@@ -58,3 +58,59 @@ def test_split_array():
     assert np.all(chunks[6] == [18, 19, 20])
     assert np.all(chunks[7] == [21, 22, 23])
     assert np.all(chunks[8] == [24])
+
+
+@pytest.mark.parametrize('presorted', [True, False])
+def test_match_int(presorted):
+    a1 = np.array([3, 10, 8, 4, 7])
+    a2 = np.array([8, 3])
+
+    if not presorted:
+        ind = np.array([4, 1, 0, 2, 3])
+        m1, m2 = eu.numpy_util.match(a1[ind], a2)
+        assert np.all(m1 == [3, 2])
+    else:
+        m1, m2 = eu.numpy_util.match(a1, a2)
+        assert np.all(m1 == [2, 0])
+
+
+@pytest.mark.parametrize('presorted', [True, False])
+def test_match_float(presorted):
+    a1 = np.array([1.25, 6.61, 8.51, 9.91, 11.25])
+    a2 = np.array([6.61, 9.91])
+
+    if not presorted:
+        ind = np.array([4, 1, 0, 2, 3])
+        m1, m2 = eu.numpy_util.match(a1[ind], a2)
+        assert np.all(m1 == [1, 4])
+    else:
+        m1, m2 = eu.numpy_util.match(a1, a2)
+        assert np.all(m1 == [1, 3])
+
+
+@pytest.mark.parametrize('presorted', [True, False])
+def test_match_str(presorted):
+    a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things'])
+    a2 = np.array(['goodbye', 'things', 'zz'])
+
+    if not presorted:
+        ind = np.array([3, 4, 0, 2, 1])
+        m1, m2 = eu.numpy_util.match(a1[ind], a2)
+        assert np.all(m1 == [4, 1])
+    else:
+        m1, m2 = eu.numpy_util.match(a1, a2)
+        assert np.all(m1 == [1, 4])
+
+
+@pytest.mark.parametrize('presorted', [True, False])
+def test_match_nomatch(presorted):
+    a1 = np.array(['blah', 'goodbye', 'hello', 'stuff', 'things'])
+    a2 = np.array(['zz', 'bb'])
+
+    if not presorted:
+        ind = np.array([3, 4, 0, 2, 1])
+        m1, m2 = eu.numpy_util.match(a1[ind], a2)
+    else:
+        m1, m2 = eu.numpy_util.match(a1, a2)
+
+    assert m1.size == 0 and m2.size == 0
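
The expected indices in these tests can be cross-checked without esutil by a brute-force lookup, sketched below. `brute_force_match` is a hypothetical helper written for this note, not part of the PR. It assumes, as `match` does, that the first array has unique elements, and it returns pairs in the order of the second array. Note also that, as written in the diff, neither branch of the tests passes `presorted=True` to `match`, so the parametrization only changes the order of the first array.

```python
import numpy as np


def brute_force_match(arr1, arr2):
    """Hypothetical reference matcher based on a dictionary lookup.

    Maps each unique arr1 value to its position, then looks up every
    arr2 value. Equality is exact, including for floats.
    """
    lookup = {value: i for i, value in enumerate(arr1.tolist())}
    pairs = [(lookup[v], j) for j, v in enumerate(arr2.tolist()) if v in lookup]
    ind1 = np.array([p[0] for p in pairs], dtype=int)
    ind2 = np.array([p[1] for p in pairs], dtype=int)
    return ind1, ind2


# Integer case, shuffled and unshuffled first array.
a1 = np.array([3, 10, 8, 4, 7])
a2 = np.array([8, 3])
ind = np.array([4, 1, 0, 2, 3])
int_shuffled, _ = brute_force_match(a1[ind], a2)
int_plain, _ = brute_force_match(a1, a2)

# String case with one value ("zz") that has no partner.
s1 = np.array(["blah", "goodbye", "hello", "stuff", "things"])
s2 = np.array(["goodbye", "things", "zz"])
sind = np.array([3, 4, 0, 2, 1])
str_shuffled, _ = brute_force_match(s1[sind], s2)
str_plain, _ = brute_force_match(s1, s2)

# No-match case: both index arrays come back empty.
none_1, none_2 = brute_force_match(s1, np.array(["zz", "bb"]))
```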
2 changes: 1 addition & 1 deletion setup.py
@@ -244,7 +244,7 @@ def build_extensions(self):

 setup(
     name="esutil",
-    version="0.6.14",
+    version="0.6.15",
     author="Erin Scott Sheldon",
     author_email="erin.sheldon@gmail.com",
     classifiers=classifiers,