diff --git a/astropy/coordinates/tests/test_matching.py b/astropy/coordinates/tests/test_matching.py index 91b6731a6cc3..324e9984e834 100644 --- a/astropy/coordinates/tests/test_matching.py +++ b/astropy/coordinates/tests/test_matching.py @@ -5,22 +5,34 @@ from numpy import testing as npt from astropy import units as u -from astropy.coordinates import matching +from astropy.coordinates import ( + ICRS, + Angle, + CartesianRepresentation, + Galactic, + SkyCoord, + match_coordinates_3d, + match_coordinates_sky, + search_around_3d, + search_around_sky, +) from astropy.tests.helper import assert_quantity_allclose as assert_allclose +from astropy.utils import NumpyRNGContext from astropy.utils.compat.optional_deps import HAS_SCIPY """ These are the tests for coordinate matching. -Note that this requires scipy. +Coordinate matching can involve caching, so it is best to recreate the +coordinate objects in every test instead of trying to reuse module-level +variables. """ +if not HAS_SCIPY: + pytest.skip("Coordinate matching requires scipy", allow_module_level=True) -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.") -def test_matching_function(): - from astropy.coordinates import ICRS - from astropy.coordinates.matching import match_coordinates_3d +def test_matching_function(): # this only uses match_coordinates_3d because that's the actual implementation cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree) @@ -37,11 +49,7 @@ def test_matching_function(): npt.assert_array_less(d3d.value, 0.02) -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.") def test_matching_function_3d_and_sky(): - from astropy.coordinates import ICRS - from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky - cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc) ccatalog = ICRS( [1, 2, 3, 4] * u.degree, [0, 0, 0, 0] * u.degree, distance=[1, 1, 1, 5] * u.kpc @@ -64,16 +72,13 @@ def test_matching_function_3d_and_sky(): @pytest.mark.parametrize( "functocheck, args, defaultkdtname, bothsaved", [ - (matching.match_coordinates_3d, [], "kdtree_3d", False), - (matching.match_coordinates_sky, [], "kdtree_sky", False), - (matching.search_around_3d, [1 * u.kpc], "kdtree_3d", True), - (matching.search_around_sky, [1 * u.deg], "kdtree_sky", False), + (match_coordinates_3d, [], "kdtree_3d", False), + (match_coordinates_sky, [], "kdtree_sky", False), + (search_around_3d, [1 * u.kpc], "kdtree_3d", True), + (search_around_sky, [1 * u.deg], "kdtree_sky", False), ], ) -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.") def test_kdtree_storage(functocheck, args, defaultkdtname, bothsaved): - from astropy.coordinates import ICRS - def make_scs(): cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 2] * u.kpc) ccatalog = ICRS( @@ -119,12 +124,7 @@ def make_scs(): assert "KD" in e.value.args[0] -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.") def test_matching_method(): - from astropy.coordinates import ICRS, SkyCoord - from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky - from astropy.utils import NumpyRNGContext - with NumpyRNGContext(987654321): cmatch = ICRS( np.random.rand(20) * 360.0 * u.degree, @@ -153,98 +153,150 @@ def test_matching_method(): assert len(idx1) == len(d2d1) == len(d3d1) == 20 -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy") -def test_search_around(): - from astropy.coordinates import ICRS, SkyCoord - from astropy.coordinates.matching import search_around_3d, search_around_sky - - coo1 = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc) - coo2 = ICRS( - [1, 2, 3, 4] * u.degree, [0, 0, 0, 0] * u.degree, distance=[1, 1, 1, 5] * u.kpc - ) - - idx1_1deg, idx2_1deg, d2d_1deg, d3d_1deg = search_around_sky( - coo1, coo2, 1.01 * u.deg - ) - idx1_0p05deg, idx2_0p05deg, d2d_0p05deg, d3d_0p05deg = search_around_sky( - coo1, coo2, 0.05 * u.deg +@pytest.mark.parametrize( + "search_limit,expected_idx1,expected_idx2,expected_d2d,expected_d3d", + [ + pytest.param( + 1.01 * u.deg, + [0, 0, 1, 1], + [2, 3, 1, 2], + [1, 0, 0.1, 0.9] * u.deg, + [0.01745307, 4.0, 4.0000019, 4.00015421] * u.kpc, + id="1.01_deg", + ), + pytest.param(0.05 * u.deg, [0], [3], [0] * u.deg, [4] * u.kpc, id="0.05_deg"), + ], +) +def test_search_around_sky( + search_limit, expected_idx1, expected_idx2, expected_d2d, expected_d3d +): + idx1, idx2, d2d, d3d = search_around_sky( + ICRS([4, 2.1] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc), + ICRS([1, 2, 3, 4] * u.deg, [0, 0, 0, 0] * u.deg, distance=[1, 1, 1, 5] * u.kpc), + search_limit, ) + npt.assert_array_equal(idx1, expected_idx1) + npt.assert_array_equal(idx2, expected_idx2) + assert_allclose(d2d, expected_d2d) + assert_allclose(d3d, expected_d3d) - assert list(zip(idx1_1deg, idx2_1deg)) == [(0, 2), (0, 3), (1, 1), (1, 2)] - assert_allclose(d2d_1deg[0], 1.0 * u.deg, atol=1e-14 * u.deg, rtol=0) - assert_allclose(d2d_1deg, [1, 0, 0.1, 0.9] * u.deg) - assert list(zip(idx1_0p05deg, idx2_0p05deg)) == [(0, 3)] - - idx1_1kpc, idx2_1kpc, d2d_1kpc, d3d_1kpc = search_around_3d(coo1, coo2, 1 * u.kpc) - idx1_sm, idx2_sm, d2d_sm, d3d_sm = search_around_3d(coo1, coo2, 0.05 * u.kpc) +@pytest.mark.parametrize( + "search_limit,expected_idx1,expected_idx2,expected_d2d,expected_d3d", + [ + pytest.param( + 1 * u.kpc, + [0, 0, 0, 1], + [0, 1, 2, 3], + [3, 2, 1, 1.9] * u.deg, + [0.0523539, 0.03490481, 0.01745307, 0.16579868] * u.kpc, + id="1_kpc", + ), + pytest.param( + 0.05 * u.kpc, + [0, 0], + [1, 2], + [2, 1] * u.deg, + [0.03490481, 0.01745307] * u.kpc, + id="0.05_kpc", + ), + ], +) +def test_search_around_3d( + search_limit, expected_idx1, expected_idx2, expected_d2d, expected_d3d +): + idx1, idx2, d2d, d3d = search_around_3d( + ICRS([4, 2.1] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc), + ICRS([1, 2, 3, 4] * u.deg, [0, 0, 0, 0] * u.deg, distance=[1, 1, 1, 5] * u.kpc), + search_limit, + ) + npt.assert_array_equal(idx1, expected_idx1) + npt.assert_array_equal(idx2, expected_idx2) + assert_allclose(d2d, expected_d2d) + assert_allclose(d3d, expected_d3d) - assert list(zip(idx1_1kpc, idx2_1kpc)) == [(0, 0), (0, 1), (0, 2), (1, 3)] - assert list(zip(idx1_sm, idx2_sm)) == [(0, 1), (0, 2)] - assert_allclose(d2d_sm, [2, 1] * u.deg) +@pytest.mark.parametrize( + "function,search_limit", + [ + pytest.param(func, limit, id=func.__name__) + for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg]) + ], +) +def test_search_around_no_matches(function, search_limit): # Test for the non-matches, #4877 - coo1 = ICRS([4.1, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc) - idx1, idx2, d2d, d3d = search_around_sky(coo1, coo2, 1 * u.arcsec) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int - assert d2d.unit == u.deg - assert d3d.unit == u.kpc - idx1, idx2, d2d, d3d = search_around_3d(coo1, coo2, 1 * u.m) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int + idx1, idx2, d2d, d3d = function( + ICRS([41, 21] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc), + ICRS([1, 2] * u.deg, [0, 0] * u.deg, distance=[1, 1] * u.kpc), + search_limit, + ) + assert idx1.size == 0 + assert idx2.size == 0 + assert d2d.size == 0 + assert d3d.size == 0 + assert idx1.dtype == int + assert idx2.dtype == int assert d2d.unit == u.deg assert d3d.unit == u.kpc + +@pytest.mark.parametrize( + "function,search_limit", + [ + pytest.param(func, limit, id=func.__name__) + for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg]) + ], +) +@pytest.mark.parametrize( + "sources,catalog", + [ + pytest.param( + ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc), + ICRS([1] * u.deg, [0] * u.deg, distance=[1] * u.kpc), + id="empty_sources", + ), + pytest.param( + ICRS([1] * u.deg, [0] * u.deg, distance=[1] * u.kpc), + ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc), + id="empty_catalog", + ), + pytest.param( + ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc), + ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc), + id="empty_both", + ), + ], +) +def test_search_around_empty_input(sources, catalog, function, search_limit): # Test when one or both of the coordinate arrays is empty, #4875 - empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc) - idx1, idx2, d2d, d3d = search_around_sky(empty, coo2, 1 * u.arcsec) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int - assert d2d.unit == u.deg - assert d3d.unit == u.kpc - idx1, idx2, d2d, d3d = search_around_sky(coo1, empty, 1 * u.arcsec) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int - assert d2d.unit == u.deg - assert d3d.unit == u.kpc - empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc) - idx1, idx2, d2d, d3d = search_around_sky(empty, empty[:], 1 * u.arcsec) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int - assert d2d.unit == u.deg - assert d3d.unit == u.kpc - idx1, idx2, d2d, d3d = search_around_3d(empty, coo2, 1 * u.m) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int - assert d2d.unit == u.deg - assert d3d.unit == u.kpc - idx1, idx2, d2d, d3d = search_around_3d(coo1, empty, 1 * u.m) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int - assert d2d.unit == u.deg - assert d3d.unit == u.kpc - idx1, idx2, d2d, d3d = search_around_3d(empty, empty[:], 1 * u.m) - assert idx1.size == idx2.size == d2d.size == d3d.size == 0 - assert idx1.dtype == idx2.dtype == int + idx1, idx2, d2d, d3d = function(sources, catalog, search_limit) + assert idx1.size == 0 + assert idx2.size == 0 + assert d2d.size == 0 + assert d3d.size == 0 + assert idx1.dtype == int + assert idx2.dtype == int assert d2d.unit == u.deg assert d3d.unit == u.kpc + +@pytest.mark.parametrize( + "function,search_limit", + [ + pytest.param(func, limit, id=func.__name__) + for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg]) + ], +) +def test_search_around_no_dist_input_output_units(function, search_limit): # Test that input without distance units results in a # 'dimensionless_unscaled' unit - cempty = SkyCoord(ra=[], dec=[], unit=u.deg) - idx1, idx2, d2d, d3d = search_around_3d(cempty, cempty[:], 1 * u.m) - assert d2d.unit == u.deg - assert d3d.unit == u.dimensionless_unscaled - idx1, idx2, d2d, d3d = search_around_sky(cempty, cempty[:], 1 * u.m) + empty_sc = SkyCoord([], [], unit=u.deg) + idx1, idx2, d2d, d3d = function(empty_sc, empty_sc[:], search_limit) assert d2d.unit == u.deg assert d3d.unit == u.dimensionless_unscaled -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy") def test_search_around_scalar(): - from astropy.coordinates import Angle, SkyCoord - cat = SkyCoord([1, 2, 3], [-30, 45, 8], unit="deg") target = SkyCoord("1.1 -30.1", unit="deg") @@ -260,10 +312,7 @@ def test_search_around_scalar(): assert "search_around_3d" in str(excinfo.value) -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy") def test_match_catalog_empty(): - from astropy.coordinates import SkyCoord - sc1 = SkyCoord(1, 2, unit="deg") cat0 = SkyCoord([], [], unit="deg") cat1 = SkyCoord([1.1], [2.1], unit="deg") @@ -290,11 +339,8 @@ def test_match_catalog_empty(): assert "catalog" in str(excinfo.value) -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy") @pytest.mark.filterwarnings(r"ignore:invalid value encountered in.*:RuntimeWarning") def test_match_catalog_nan(): - from astropy.coordinates import Galactic, SkyCoord - sc1 = SkyCoord(1, 2, unit="deg") sc_with_nans = SkyCoord(1, np.nan, unit="deg") @@ -324,11 +370,7 @@ def test_match_catalog_nan(): assert "Matching coordinates cannot contain" in str(excinfo.value) -@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy") def test_match_catalog_nounit(): - from astropy.coordinates import ICRS, CartesianRepresentation - from astropy.coordinates.matching import match_coordinates_sky - i1 = ICRS([[1], [2], [3]], representation_type=CartesianRepresentation) i2 = ICRS([[1], [2], [4, 5]], representation_type=CartesianRepresentation) i, sep, sep3d = match_coordinates_sky(i1, i2)