From dab1009f749f1c4414b2282d5a7aeb9575f8a774 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Wed, 8 Jan 2025 01:40:54 -0500 Subject: [PATCH] Add unit tests that use gwcs. Fix how the bounding box is handled. (#167) * Add unit tests that use gwcs * Fix how bounding box is handled --- drizzle/tests/helpers.py | 144 +++++++++++++++++++++++++++++++++++- drizzle/tests/test_utils.py | 29 +++++--- drizzle/utils.py | 11 +++ pyproject.toml | 1 + 4 files changed, 173 insertions(+), 12 deletions(-) diff --git a/drizzle/tests/helpers.py b/drizzle/tests/helpers.py index 53d5d1b..0e44c4d 100644 --- a/drizzle/tests/helpers.py +++ b/drizzle/tests/helpers.py @@ -1,9 +1,21 @@ import os +import gwcs import numpy as np +from gwcs.coordinate_frames import CelestialFrame, Frame2D -from astropy import wcs +from astropy import coordinates as coord +from astropy import units +from astropy import wcs as fits_wcs from astropy.io import fits +from astropy.modeling.models import ( + Mapping, + Pix2Sky_TAN, + Polynomial2D, + RotateNative2Celestial, + Shift, +) +from astropy.modeling.projections import AffineTransformation2D __all__ = ["wcs_from_file"] @@ -11,9 +23,36 @@ DATA_DIR = os.path.join(TEST_DIR, 'data') -def wcs_from_file(filename, ext=None, return_data=False, crpix_shift=None): +def wcs_from_file(filename, ext=None, return_data=False, crpix_shift=None, + wcs_type="fits"): """ Read the WCS from a ".fits" file. + + Parameters + ---------- + filename : str + Name of the file to load WCS from. + + ext : int, None, optional + Extension number to load the WCS from. When `None`, the WCS will be + loaded from the first extension containing a WCS. + + return_data : bool, optional + When `True`, this function will return a tuple with first item + being the WCS and the second item being the image data array. + + crpix_shift : tuple, None, optional + A tuple of two values to be added to header CRPIX values before + creating the WCS. This effectively introduces a constant shift + in the image coordinate system. + + wcs_type : {"fits", "gwcs"}, optional + Return either a FITS WCS or a gwcs. + + Returns + ------- + WCS or tuple of WCS and image data + """ full_file_name = os.path.join(DATA_DIR, filename) path = os.path.join(DATA_DIR, full_file_name) @@ -38,9 +77,12 @@ def wcs_from_file(filename, ext=None, return_data=False, crpix_shift=None): hdr["CRPIX1"] += crpix_shift[0] hdr["CRPIX2"] += crpix_shift[1] - result = wcs.WCS(hdr, hdu) + result = fits_wcs.WCS(hdr, hdu) result.array_shape = shape + if wcs_type == "gwcs": + result = _gwcs_from_hst_fits_wcs(result) + if return_data: result = (result, ) if not isinstance(return_data, (list, tuple)): @@ -50,3 +92,99 @@ def wcs_from_file(filename, ext=None, return_data=False, crpix_shift=None): result = result + data return result + + +def _gwcs_from_hst_fits_wcs(w): + # NOTE: this function ignores table distortions + def coeffs_to_poly(mat, degree): + pol = Polynomial2D(degree=degree) + for i in range(mat.shape[0]): + for j in range(mat.shape[1]): + if 0 < i + j <= degree: + setattr(pol, f'c{i}_{j}', mat[i, j]) + return pol + + nx, ny = w.pixel_shape + x0, y0 = w.wcs.crpix - 1 + + cd = w.wcs.piximg_matrix + + if w.sip is None: + # construct GWCS: + det2sky = ( + (Shift(-x0) & Shift(-y0)) | + Pix2Sky_TAN() | RotateNative2Celestial(*w.wcs.crval, 180) + ) + else: + cfx, cfy = np.dot(cd, [w.sip.a.ravel(), w.sip.b.ravel()]) + a = np.reshape(cfx, w.sip.a.shape) + b = np.reshape(cfy, w.sip.b.shape) + a[1, 0] = cd[0, 0] + a[0, 1] = cd[0, 1] + b[1, 0] = cd[1, 0] + b[0, 1] = cd[1, 1] + + polx = coeffs_to_poly(a, w.sip.a_order) + poly = coeffs_to_poly(b, w.sip.b_order) + + sip = Mapping((0, 1, 0, 1)) | (polx & poly) + + # construct GWCS: + det2sky = ( + (Shift(-x0) & Shift(-y0)) | sip | + Pix2Sky_TAN() | RotateNative2Celestial(*w.wcs.crval, 180) + ) + + detector_frame = Frame2D( + name="detector", + axes_names=("x", "y"), + unit=(units.pix, units.pix) + ) + sky_frame = CelestialFrame( + reference_frame=getattr(coord, w.wcs.radesys).__call__(), + name=w.wcs.radesys, + unit=(units.deg, units.deg) + ) + pipeline = [(detector_frame, det2sky), (sky_frame, None)] + gw = gwcs.wcs.WCS(pipeline) + gw.array_shape = w.array_shape + gw.bounding_box = ((-0.5, nx - 0.5), (-0.5, ny - 0.5)) + + if w.sip is not None: + # compute inverse SIP and re-create output GWCS + + # compute inverse SIP: + hdr = gw.to_fits_sip( + max_inv_pix_error=1e-5, + inv_degree=None, + npoints=64, + crpix=w.wcs.crpix, + projection='TAN', + verbose=False + ) + winv = fits_wcs.WCS(hdr) + ap = winv.sip.ap.copy() + bp = winv.sip.bp.copy() + ap[1, 0] += 1 + bp[0, 1] += 1 + polx_inv = coeffs_to_poly(ap, winv.sip.ap_order) + poly_inv = coeffs_to_poly(bp, winv.sip.bp_order) + af = AffineTransformation2D( + matrix=np.linalg.inv(winv.wcs.piximg_matrix) + ) + + # set analytical inverses: + sip.inverse = af | Mapping((0, 1, 0, 1)) | (polx_inv & poly_inv) + + # construct GWCS: + det2sky = ( + (Shift(-x0) & Shift(-y0)) | sip | + Pix2Sky_TAN() | RotateNative2Celestial(*w.wcs.crval, 180) + ) + + pipeline = [(detector_frame, det2sky), (sky_frame, None)] + gw = gwcs.wcs.WCS(pipeline) + gw.array_shape = w.array_shape + gw.bounding_box = ((-0.5, nx - 0.5), (-0.5, ny - 0.5)) + + return gw diff --git a/drizzle/tests/test_utils.py b/drizzle/tests/test_utils.py index d5fe2a2..5e263de 100644 --- a/drizzle/tests/test_utils.py +++ b/drizzle/tests/test_utils.py @@ -25,14 +25,17 @@ def test_map_rectangular(): assert_equal(pixmap[5, 500], (500, 5)) -def test_map_to_self(): +@pytest.mark.parametrize( + "wcs_type", ["fits", "gwcs"] +) +def test_map_to_self(wcs_type): """ Map a pixel array to itself. Should return the same array. """ - input_wcs = wcs_from_file("j8bt06nyq_sip_flt.fits", ext=1) + input_wcs = wcs_from_file("j8bt06nyq_sip_flt.fits", ext=1, wcs_type=wcs_type) shape = input_wcs.array_shape - ok_pixmap = np.indices(shape, dtype='float32') + ok_pixmap = np.indices(shape, dtype='float64') ok_pixmap = ok_pixmap.transpose() pixmap = calc_pixmap(input_wcs, input_wcs) @@ -47,9 +50,10 @@ def test_map_to_self(): pixmap = calc_pixmap(input_wcs, input_wcs, (12, 34)) assert_equal(pixmap.shape, (12, 34, 2)) - # Check that an exception is raised for WCS without pixel_shape or + # Check that an exception is raised for WCS without pixel_shape and # bounding_box: input_wcs.pixel_shape = None + input_wcs.bounding_box = None with pytest.raises(ValueError): calc_pixmap(input_wcs, input_wcs) @@ -68,17 +72,24 @@ def test_map_to_self(): assert_equal(pixmap.shape, ok_pixmap.shape) -def test_translated_map(): +@pytest.mark.parametrize( + "wcs_type", ["fits", "gwcs"] +) +def test_translated_map(wcs_type): """ Map a pixel array to at translated array. """ - first_wcs = wcs_from_file("j8bt06nyq_sip_flt.fits", ext=1) + first_wcs = wcs_from_file( + "j8bt06nyq_sip_flt.fits", + ext=1, + wcs_type=wcs_type + ) second_wcs = wcs_from_file( "j8bt06nyq_sip_flt.fits", ext=1, - crpix_shift=(-2, -2) # shift loaded WCS by subtracting this from CRPIX + crpix_shift=(-2, -2), # shift loaded WCS by adding this to CRPIX + wcs_type=wcs_type ) - assert np.allclose(second_wcs.wcs.crpix, (510, 510)) ok_pixmap = np.indices(first_wcs.array_shape, dtype='float32') - 2.0 ok_pixmap = ok_pixmap.transpose() @@ -88,7 +99,7 @@ def test_translated_map(): # Got x-y transpose right assert_equal(pixmap.shape, ok_pixmap.shape) # Mapping an array to a translated array - assert_almost_equal(pixmap, ok_pixmap, decimal=5) + assert_almost_equal(pixmap[2:, 2:], ok_pixmap[2:, 2:], decimal=5) def test_estimate_pixel_scale_ratio(): diff --git a/drizzle/utils.py b/drizzle/utils.py index a3362af..d0d4ebf 100644 --- a/drizzle/utils.py +++ b/drizzle/utils.py @@ -61,6 +61,17 @@ def calc_pixmap(wcs_from, wcs_to, shape=None): shape = wcs_from.array_shape if shape is None: if (bbox := getattr(wcs_from, "bounding_box", None)) is not None: + try: + # to avoid dependency on astropy just to check whether + # the bounding box is an instance of + # modeling.bounding_box.ModelBoundingBox, we try to + # directly use and bounding_box(order='F') and if it fails, + # fall back to converting the bounding box to a tuple + # (of intervals): + bbox = bbox.bounding_box(order='F') + except AttributeError: + bbox = tuple(bbox) + if (nd := np.ndim(bbox)) == 1: bbox = (bbox, ) if nd > 1: diff --git a/pyproject.toml b/pyproject.toml index 1b6a769..361e495 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ Documentation = "http://spacetelescope.github.io/drizzle/" [project.optional-dependencies] test = [ "astropy", + "gwcs", "pytest", "pytest-cov", "pytest-doctestplus",