From 52f21be65d9c7a2c8270e6f0dbd4575950e80826 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Fri, 23 Dec 2022 22:51:33 -0500 Subject: [PATCH] Fix bin size in 2D histogram used in matching (#173) * Fix bin size in 2D histogram used in matching * deprecate tp_wcs argument * Simplify code, improve docs, prepare release. * Add unit test. Fix warning message --- CHANGELOG.rst | 15 ++- tweakwcs/correctors.py | 5 + tweakwcs/matchutils.py | 134 +++++++++++++---------- tweakwcs/tests/test_matchutils.py | 36 +++--- tweakwcs/tests/test_multichip_fitswcs.py | 2 +- tweakwcs/tests/test_multichip_jwst.py | 2 +- tweakwcs/tests/test_wcsimage.py | 18 +++ tweakwcs/wcsimage.py | 31 +++++- 8 files changed, 170 insertions(+), 73 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ad1d53a..9c90eb0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,9 +4,22 @@ Release Notes ============= -.. 0.8.1 (unreleased) +.. 0.8.2 (unreleased) ================== +0.8.1 (23-December-2022) +======================== + +- Fixed a bug in the ``XYXYMatch`` due to which bin size for the 2D histogram + pre-match alignment did not account for the pixel scale in the tangent plane. + This required a change in the API of ``XYXYMatch.__call__`` which now + _must_ have ``tp_pscale`` as input and also inputs catalogs now _must_ + contain ``'TPx'`` and ``'TPy'`` columns. [#173] + +- Deprecated ``'tp_wcs'`` argument of the ``XYXYMatch.__call__()`` method. + Use ``'tp_pscale'`` instead. [#173] + + 0.8.0 (25-August-2022) ====================== diff --git a/tweakwcs/correctors.py b/tweakwcs/correctors.py index eddaf5e..876bb77 100644 --- a/tweakwcs/correctors.py +++ b/tweakwcs/correctors.py @@ -74,6 +74,7 @@ class WCSCorrector(ABC): and for managing tangent-plane corrections. """ + units = None def __init__(self, wcs, meta=None): """ @@ -278,6 +279,8 @@ class FITSWCSCorrector(WCSCorrector): supported. """ + units = 'pixel' + def __init__(self, wcs, meta=None): """ Parameters @@ -557,6 +560,8 @@ class JWSTWCSCorrector(WCSCorrector): tangent-plane corrections. """ + units = 'arcsec' + def __init__(self, wcs, wcsinfo, meta=None): """ Parameters diff --git a/tweakwcs/matchutils.py b/tweakwcs/matchutils.py index 4d21a87..19bc7d0 100644 --- a/tweakwcs/matchutils.py +++ b/tweakwcs/matchutils.py @@ -7,11 +7,13 @@ """ import logging +import warnings from abc import ABC, abstractmethod import numpy as np import astropy from astropy.utils.decorators import deprecated +from astropy.utils.exceptions import AstropyDeprecationWarning from stsci.stimage import xyxymatch @@ -33,7 +35,7 @@ def __init__(self): """ @abstractmethod - def __call__(self, refcat, imcat): + def __call__(self, refcat, imcat, **kwargs): """ Performs catalog matching. Parameters @@ -51,6 +53,9 @@ def __call__(self, refcat, imcat): (distortion-correction applied) source coordinates coordinate system common (shared) with the image catalog ``refcat``. + **kwargs : dict + Any keyword arguments for ``__call__`` specific to subclass. + Returns ------- (refcat_idx, imcat_idx): tuple of numpy.ndarray @@ -138,41 +143,35 @@ def __init__(self, searchrad=3.0, separation=0.5, use2dhist=True, self._xoffset = float(xoffset) self._yoffset = float(yoffset) - def __call__(self, refcat, imcat, tp_wcs=None): + def __call__(self, refcat, imcat, tp_pscale=1.0, tp_units=None, **kwargs): r""" Performs catalog matching. Parameters ---------- refcat: astropy.table.Table - A reference source catalog. When a tangent-plane ``WCS`` is - provided through ``tp_wcs``, the catalog must contain ``'RA'`` and - ``'DEC'`` columns which indicate reference source world - coordinates (in degrees). Alternatively, when ``tp_wcs`` is `None`, - reference catalog must contain ``'TPx'`` and ``'TPy'`` columns that - provide undistorted (distortion-correction applied) source - coordinates in some *tangent plane*. In this case, the ``'RA'`` - and ``'DEC'`` columns in the ``refcat`` catalog will be ignored. + A reference source catalog. Reference catalog must contain + ``'TPx'`` and ``'TPy'`` columns that provide undistorted + (distortion-correction applied) source coordinates in some + *tangent plane*. The ``'RA'`` and ``'DEC'`` columns in the + ``refcat`` catalog will be ignored. imcat: astropy.table.Table - Source catalog associated with an image. Must contain ``'x'`` and - ``'y'`` columns which indicate source coordinates (in pixels) in - the associated image. Alternatively, when ``tp_wcs`` is `None`, - catalog must contain ``'TPx'`` and ``'TPy'`` columns that - provide undistorted (distortion-correction applied) source + Source catalog associated with an image. The catalog must contain + ``'TPx'`` and ``'TPy'`` columns that provide undistorted + (distortion-correction applied) source coordinates in **the same**\ *tangent plane* used to define ``refcat``'s tangent plane coordinates. In this case, the ``'x'`` and ``'y'`` columns in the ``imcat`` catalog will be ignored. - tp_wcs: WCSCorrector, None, optional - A ``WCS`` that defines a tangent plane onto which both - reference and image catalog sources can be projected. For this - reason, ``tp_wcs`` is associated with the image in which sources - from the ``imcat`` catalog were found in the sense that ``tp_wcs`` - must be able to map image coordinates ``'x'`` and ``'y'`` from the - ``imcat`` catalog to the tangent plane. When ``tp_wcs`` is - provided, the ``'TPx'`` and ``'TPy'`` columns in both ``imcat`` and - ``refcat`` catalogs will be ignored (if present). + tp_pscale: float + Pixel scale: size of an image pixel in the tangent plane. + Pixel scale is in the same units as the coordinates of the tangent + plane. Pixel scale is used to compute bin size used for + initial 2D histogram alignment performed before matching. + + tp_units: str + Units of the tangent plane coordinates. Returns ------- @@ -198,6 +197,16 @@ def __call__(self, refcat, imcat, tp_wcs=None): raise ValueError("Image catalog must contain at least one " "source.") + if 'tp_wcs' in kwargs: + warnings.warn( + "Argument 'tp_wcs' has been deprecated since version 0.8.1. " + "Please use 'tp_pscale' instead and populate 'TPx' and 'TPy' " + "columns of input catalogs.", + AstropyDeprecationWarning + ) + + tp_wcs = kwargs.get('tp_wcs') + if tp_wcs is None: if 'TPx' not in refcat.colnames or 'TPy' not in refcat.colnames: raise KeyError("When tangent plane WCS is not provided, " @@ -242,27 +251,26 @@ def __call__(self, refcat, imcat, tp_wcs=None): "reference '{:s}' catalog." .format(imcat_name, refcat_name)) - ps = 1.0 if tp_wcs is None else tp_wcs.tanp_center_pixel_scale - if self._use2dhist: # Determine xyoff (X,Y offset) and tolerance # to be used with xyxymatch: - zpxoff, zpyoff = _estimate_2dhist_shift( - imxy / ps, - refxy / ps, - searchrad=self._searchrad + xyoff = _estimate_2dhist_shift( + imxy, + refxy, + searchrad=self._searchrad, + pscale=tp_pscale, + units=tp_units ) - xyoff = (zpxoff * ps, zpyoff * ps) else: - xyoff = (self._xoffset * ps, self._yoffset * ps) + xyoff = (self._xoffset, self._yoffset) matches = xyxymatch( imxy, refxy, origin=xyoff, - tolerance=ps * self._tolerance, - separation=ps * self._separation + tolerance=self._tolerance, + separation=self._separation ) return matches['ref_idx'], matches['input_idx'] @@ -281,43 +289,57 @@ def _xy_2dhist(imgxy, refxy, r): return h[0].T -def _estimate_2dhist_shift(imgxy, refxy, searchrad=3.0): +def _estimate_2dhist_shift(imgxy, refxy, searchrad=3.0, pscale=1.0, units=None): """ Create a 2D matrix-histogram which contains the delta between each XY position and each UV position. Then estimate initial offset between catalogs. + + ``pscale`` is used to make bins of size approximately equal to + image pixel. + """ log.info("Computing initial guess for X and Y shifts...") + if units is None: + units = 'tangent plane units' # create ZP matrix - zpmat = _xy_2dhist(imgxy, refxy, r=searchrad) + zpmat = _xy_2dhist(imgxy / pscale, refxy / pscale, r=searchrad / pscale) nonzeros = np.count_nonzero(zpmat) if nonzeros == 0: # no matches within search radius. Return (0, 0): - log.warning("No matches found within a search radius of {:g} pixels." - .format(searchrad)) + log.warning( + f"No matches found within a search radius of {searchrad:g} ({units})." + ) return 0.0, 0.0 elif nonzeros == 1: # only one non-zero bin: yp, xp = np.unravel_index(np.argmax(zpmat), zpmat.shape) maxval = zpmat[yp, xp] - xp -= searchrad - yp -= searchrad - log.info("Found initial X and Y shifts of {:.4g}, {:.4g} " - "based on a single non-zero bin and {} matches" - .format(xp, yp, int(maxval))) + xp = pscale * xp - searchrad + yp = pscale * yp - searchrad + + log.info( + f"Found initial X and Y shifts of {xp:.4g}, {yp:.4g} ({units}) " + f"based on a single non-zero bin and {int(maxval):d} matches." + ) return xp, yp - (xp, yp), fit_status, fit_sl = _find_peak(zpmat, peak_fit_box=5, - mask=zpmat > 0) + (xp, yp), fit_status, fit_sl = _find_peak( + zpmat, + peak_fit_box=5, + mask=zpmat > 0 + ) + if fit_status.startswith('ERROR'): - log.warning("No valid shift found within a search radius of {:g} " - "pixels.".format(searchrad)) + log.warning( + f"No valid shift found within a search radius of {searchrad:g} {units}." + ) return 0.0, 0.0 - xp -= searchrad - yp -= searchrad + xp = pscale * xp - searchrad + yp = pscale * yp - searchrad if fit_status == 'WARNING:EDGE': log.info("Found peak in the 2D histogram lies at the edge of the " @@ -334,16 +356,18 @@ def _estimate_2dhist_shift(imgxy, refxy, searchrad=3.0): if bkg > 0: # pragma: no branch bkg = zpmat[zpmat_mask].mean() sig = maxval / np.sqrt(bkg) - log.info("Found initial X and Y shifts of {:.4g}, {:.4g} " - "with significance of {:.4g} and {:d} matches." - .format(xp, yp, sig, flux)) + log.info( + f"Found initial X and Y shifts of {xp:.4g}, {yp:.4g} ({units}) " + f"with significance of {sig:.4g} and {flux:d} matches." + ) else: log.warning("Unable to estimate significance of the detection of the " "initial shift.") - log.info("Found initial X and Y shifts of {:.4g}, {:.4g} " - "with {:d} matches." - .format(xp, yp, flux)) + log.info( + f"Found initial X and Y shifts of {xp:.4g}, {yp:.4g} ({units}) " + f"with {flux:d} matches." + ) return xp, yp diff --git a/tweakwcs/tests/test_matchutils.py b/tweakwcs/tests/test_matchutils.py index e2772f4..2a3216d 100644 --- a/tweakwcs/tests/test_matchutils.py +++ b/tweakwcs/tests/test_matchutils.py @@ -199,14 +199,14 @@ def test_estimate_2dhist_shift_one_bin(shift): imgxy = np.zeros((1, 2)) refxy = imgxy - shift expected = 2 * (0 if shift > 3 else shift, ) - assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3) == expected + assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3, pscale=1.0) == expected def test_estimate_2dhist_shift_edge(): imgxy = np.array([[0, 0], [0, 1], [3, 4], [7, 8]]) shifts = np.array([[3, 0], [3, 0], [1, 2], [0, 1]]) refxy = imgxy - shifts - assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3) == (3.0, 0.0) + assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3, pscale=1.0) == (3.0, 0.0) def test_estimate_2dhist_shift_fit_failed(monkeypatch): @@ -218,15 +218,15 @@ def fake_find_peak(data, peak_fit_box=5, mask=None): imgxy = np.array([[0, 0], [0, 1], [3, 4], [7, 8]]) shifts = np.array([[3, 0], [3, 0], [1, 2], [0, 1]]) refxy = imgxy - shifts - assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3) == (0.0, 0.0) + assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3, pscale=1.0) == (0.0, 0.0) def test_estimate_2dhist_shift_two_equal_maxima(caplog): imgxy = np.array([[0, 1], [0, 1]]) refxy = np.array([[1, 0], [0, 2]]) - assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3) == (-0.5, 0.0) + assert _estimate_2dhist_shift(imgxy, refxy, searchrad=3, pscale=1.0) == (-0.5, 0.0) assert (caplog.record_tuples[-1][2] == "Found initial X and Y shifts of " - "-0.5, 0 with 4 matches." and + "-0.5, 0 (tangent plane units) with 4 matches." and caplog.record_tuples[-1][1] == logging.INFO) assert (caplog.record_tuples[-2][2] == "Unable to estimate significance " "of the detection of the initial shift." and @@ -262,9 +262,15 @@ def test_tpmatch_bad_pars(searchrad, separation, tolerance): DummyWCSCorrector(WCS()), KeyError), ]) def test_tpmatch_bad_call_pars(refcat, imcat, tp_wcs, exception): + tp_pscale = tp_wcs.tanp_center_pixel_scale if tp_wcs else 1.0 tpmatch = XYXYMatch() with pytest.raises(exception): - tpmatch(refcat, imcat, tp_wcs) + tpmatch( + refcat, + imcat, + tp_pscale=tp_pscale, + tp_units=None if tp_wcs is None else tp_wcs.units + ) @pytest.mark.parametrize('tp_wcs, use2dhist', [ @@ -275,14 +281,16 @@ def test_tpmatch_bad_call_pars(refcat, imcat, tp_wcs, exception): ]) def test_tpmatch(tp_wcs, use2dhist): tpmatch = XYXYMatch(use2dhist=use2dhist) - if tp_wcs: - imcat = Table([[1], [1]], names=('x', 'y'), meta={'name': None}) - refcat = Table([[1], [1]], names=('RA', 'DEC'), meta={'name': None}) - else: - refcat = Table([[1], [1]], names=('TPx', 'TPy'), meta={'name': None}) - imcat = Table([[1], [1]], names=('TPx', 'TPy'), meta={'name': None}) - - tpmatch(refcat, imcat, tp_wcs) + refcat = Table([[1], [1]], names=('TPx', 'TPy'), meta={'name': None}) + imcat = Table([[1], [1]], names=('TPx', 'TPy'), meta={'name': None}) + + tp_pscale = tp_wcs.tanp_center_pixel_scale if tp_wcs else 1.0 + tpmatch( + refcat, + imcat, + tp_pscale=tp_pscale, + tp_units=None if tp_wcs is None else tp_wcs.units + ) def test_match_catalogs_abc(): diff --git a/tweakwcs/tests/test_multichip_fitswcs.py b/tweakwcs/tests/test_multichip_fitswcs.py index 00abd70..113c6ca 100644 --- a/tweakwcs/tests/test_multichip_fitswcs.py +++ b/tweakwcs/tests/test_multichip_fitswcs.py @@ -10,7 +10,7 @@ import tweakwcs -def _match(x, y): +def _match(x, y, tp_pscale, tp_units, **kwargs): lenx = len(x) leny = len(y) if lenx == leny: diff --git a/tweakwcs/tests/test_multichip_jwst.py b/tweakwcs/tests/test_multichip_jwst.py index 21af83b..3533842 100644 --- a/tweakwcs/tests/test_multichip_jwst.py +++ b/tweakwcs/tests/test_multichip_jwst.py @@ -112,7 +112,7 @@ def _make_gwcs_wcs(fits_hdr): return gw -def _match(x, y): +def _match(x, y, tp_pscale, tp_units, **kwargs): lenx = len(x) leny = len(y) if lenx == leny: diff --git a/tweakwcs/tests/test_wcsimage.py b/tweakwcs/tests/test_wcsimage.py index 32860d4..95892b8 100644 --- a/tweakwcs/tests/test_wcsimage.py +++ b/tweakwcs/tests/test_wcsimage.py @@ -236,6 +236,24 @@ def test_wcsgroupcat_update_bb_no_images(mock_fits_wcs, rect_imcat): assert len(g.polygon) == 0 +def test_wcsgroupcat_empty_cat(mock_fits_wcs, rect_imcat): + imcat = Table([[], [], [], []], names=('x', 'y', 'TPx', 'TPy')) + corr = FITSWCSCorrector(mock_fits_wcs) + + ra, dec = mock_fits_wcs.all_pix2world(rect_imcat.catalog['x'], + rect_imcat.catalog['y'], 0) + refcat = Table([ra, dec], names=('RA', 'DEC')) + ref = RefCatalog(refcat) + + w = WCSImageCatalog(imcat, corr) + ref.calc_tanp_xy(tanplane_wcs=rect_imcat.corrector) + g = WCSGroupCatalog([w]) + g.calc_tanp_xy(tanplane_wcs=rect_imcat.corrector) + + nmatches, *_ = g.match2ref(ref, match=XYXYMatch()) + assert nmatches == 0 + + def test_wcsgroupcat_create_group_catalog(mock_fits_wcs, rect_imcat): w1 = copy.deepcopy(rect_imcat) w2 = copy.deepcopy(rect_imcat) diff --git a/tweakwcs/wcsimage.py b/tweakwcs/wcsimage.py index 509e4a4..5b73921 100644 --- a/tweakwcs/wcsimage.py +++ b/tweakwcs/wcsimage.py @@ -825,6 +825,20 @@ def match2ref(self, refcat, match=None): ``'TPy'`` that represent the source coordinates in some common (to both catalogs) coordinate system. + Returns + ------- + + nmatches: int + Number of found matches. + + mref_idx: numpy.ndarray + Integer array indicating indices of sources in the reference + catalog that were matched to sources in group's ``catalog``. + + minput_idx: numpy.ndarray + Integer array indicating indices of sources in group's ``catalog`` + that were matched to sources in the reference catalog. + """ colnames = self._catalog.colnames catlen = len(self._catalog) @@ -848,7 +862,22 @@ def match2ref(self, refcat, match=None): raise RuntimeError("'calc_tanp_xy()' should have been run " "prior to match2ref()") - mref_idx, minput_idx = match(refcat.catalog, self._catalog) + if catlen == 0: + return 0, np.array([], dtype=int), np.array([], dtype=int) + + try: + tp_pscale = self._images[0].corrector.tanp_center_pixel_scale + except NotImplementedError: + tp_pscale = 1.0 + finally: + tp_units = self._images[0].corrector.units + + mref_idx, minput_idx = match( + refcat.catalog, + self._catalog, + tp_pscale=tp_pscale, + tp_units=tp_units + ) nmatches = len(mref_idx) # matched_ref_id: