From 00488aa00fa66eb2eb2b8c94f2e6fb812d570018 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Wed, 13 Sep 2023 05:32:31 -0700 Subject: [PATCH 1/5] Refactor of WCS pixelization operator - Move the application of source centering in the projection to a separate helper function in `pointing_utils.py`. Thanks to @gabrielecoppi for identifying this fix. Optionally use this new function when computing the scan range for autoscaling. - In `PixelsWCS`: - Add a new general class method that computes the WCS parameters. - Add support for SFL projection. - Allow projection traits to be changed in any order and only recompute the WCS if needed when exec() is called. - Default to a single submap, which is the most efficient choice for the common case of data distributed by detector and many observations co-incident on the sky. - In the PixelsWCS unit tests: - Ensure projection and plotting works for every supported projection type with both fixed parameters and autoscaling. - Test mapmaking in both normal mode and with source-centered projections in RA/DEC and Az/El. - In `plot_wcs_maps`: - Set the figure size based on the DPI and the actual size of the image in pixels. - Set the unhit pixels to gray. - Allow specifying the color map, and default to one of the perceptially uniform ones. --- src/toast/ops/pixels_wcs.py | 454 +++++++++++++++------------ src/toast/pixels.py | 5 +- src/toast/pixels_io_wcs.py | 18 +- src/toast/pointing_utils.py | 148 +++++---- src/toast/tests/ops_pointing_wcs.py | 458 ++++++++++++++++++++-------- src/toast/vis.py | 80 +++-- 6 files changed, 753 insertions(+), 410 deletions(-) diff --git a/src/toast/ops/pixels_wcs.py b/src/toast/ops/pixels_wcs.py index 792b139b0..97a68d62f 100644 --- a/src/toast/ops/pixels_wcs.py +++ b/src/toast/ops/pixels_wcs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file. +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. @@ -8,12 +8,14 @@ import traitlets from astropy import units as u from astropy.wcs import WCS +import astropy.io.fits as af from .. import qarray as qa +from ..instrument_coords import quat_to_xieta from ..mpi import MPI from ..observation import default_values as defaults from ..pixels import PixelDistribution -from ..pointing_utils import scan_range_lonlat +from ..pointing_utils import scan_range_lonlat, center_offset_lonlat from ..timing import function_timer from ..traits import Bool, Instance, Int, Tuple, Unicode, trait_docs from ..utils import Environment, Logger @@ -34,9 +36,9 @@ class PixelsWCS(Operator): If the view trait is not specified, then this operator will use the same data view as the detector pointing operator when computing the pointing matrix pixels. - This uses the astropy wcs utilities to build the projection parameters. By - default, the world to pixel conversion is performed with internal, optimized code - unless use_astropy is set to True. + This uses the astropy wcs utilities to build the projection parameters. Eventually + this operator will use internal kernels for the projection unless `use_astropy` + is set to True. """ @@ -50,11 +52,20 @@ class PixelsWCS(Operator): help="Operator that translates boresight pointing into detector frame", ) - projection = Unicode("CAR", help="Supported values are CAR, CEA, MER, ZEA, TAN") + fits_header = Unicode( + None, + allow_none=True, + help="FITS file containing header to use with pre-existing WCS parameters", + ) + + coord_frame = Unicode("EQU", help="Supported values are AZEL, EQU, GAL, ECL") + + projection = Unicode( + "CAR", help="Supported values are CAR, CEA, MER, ZEA, TAN, SFL" + ) center = Tuple( - (180 * u.degree, 0 * u.degree), - allow_none=True, + tuple(), help="The center Lon/Lat coordinates (Quantities) of the projection", ) @@ -70,17 +81,17 @@ class PixelsWCS(Operator): ) auto_bounds = Bool( - False, + True, help="If True, set the bounding box based on boresight and field of view", ) dimensions = Tuple( - (710, 350), + (1000, 1000), help="The Lon/Lat pixel dimensions of the projection", ) resolution = Tuple( - (0.5 * u.degree, 0.5 * u.degree), + tuple(), help="The Lon/Lat projection resolution (Quantities) along the 2 axes", ) @@ -90,7 +101,7 @@ class PixelsWCS(Operator): pixels = Unicode("pixels", help="Observation detdata key for output pixel indices") - submaps = Int(10, help="Number of submaps to use") + submaps = Int(1, help="Number of submaps to use") create_dist = Unicode( None, @@ -128,198 +139,245 @@ def _check_detector_pointing(self, proposal): @traitlets.validate("wcs_projection") def _check_wcs_projection(self, proposal): check = proposal["value"] - if check not in ["CAR", "CEA", "MER", "ZEA", "TAN"]: + if check not in ["CAR", "CEA", "MER", "ZEA", "TAN", "SFL"]: raise traitlets.TraitError("Invalid WCS projection name") return check def __init__(self, **kwargs): super().__init__(**kwargs) - # If running with all default values, the 'observe' function will not - # have been called yet. - if not hasattr(self, "_local_submaps"): - self._set_wcs( - self.projection, - self.center, - self.bounds, - self.dimensions, - self.resolution, - ) + # Track whether we need to recompute autobounds self._done_auto = False + # Track whether we need to recompute the WCS projection + self._done_wcs = False @traitlets.observe("auto_bounds") def _reset_auto_bounds(self, change): # Track whether we need to recompute the bounds. - if change["new"]: - # enabling + old_val = change["old"] + new_val = change["new"] + if new_val != old_val: self._done_auto = False - else: - self._done_auto = True + self._done_wcs = False @traitlets.observe("center_offset") def _reset_auto_center(self, change): - # Track whether we need to recompute the bounds. - if change["new"] is not None: - if self.auto_bounds: - self._done_auto = False + old_val = change["old"] + new_val = change["new"] + # Track whether we need to recompute the projection + if new_val != old_val: + self._done_wcs = False + self._done_auto = False @traitlets.observe("projection", "center", "bounds", "dimensions", "resolution") def _reset_wcs(self, change): # (Re-)initialize the WCS projection when one of these traits change. - # Current values: - proj = str(self.projection) - center = self.center - if len(center) > 0: - center = tuple(self.center) - bounds = self.bounds - if len(bounds) > 0: - bounds = tuple(self.bounds) - dims = self.dimensions - if len(dims) > 0: - dims = tuple(self.dimensions) - res = self.resolution - if len(res) > 0: - res = tuple(self.resolution) - - # Update to the trait that changed - if change["name"] == "projection": - proj = change["new"] - if change["name"] == "center": - center = change["new"] - if len(center) > 0: - bounds = tuple() - if change["name"] == "bounds": - bounds = change["new"] - if len(bounds) > 0: - center = tuple() - if len(dims) > 0 and len(res) > 0: - # Most likely the user cares about the resolution more... - dims = tuple() - if change["name"] == "dimensions": - dims = change["new"] - if len(dims) > 0 and len(bounds) > 0: - res = tuple() - if change["name"] == "resolution": - res = change["new"] - if len(res) > 0 and len(bounds) > 0: - dims = tuple() - self._set_wcs(proj, center, bounds, dims, res) - self.projection = proj - self.center = center - self.bounds = bounds - self.dimensions = dims - self.resolution = res - - def _set_wcs(self, proj, center, bounds, dims, res): + old_val = change["old"] + new_val = change["new"] + if old_val != new_val: + self._done_wcs = False + self._done_auto = False + + @classmethod + def create_wcs( + cls, + coord="EQU", + proj="CAR", + center_deg=None, + bounds_deg=None, + res_deg=None, + dims=None, + ): + """Create a WCS object given projection parameters. + + Either the `center_deg` or `bounds_deg` parameters must be specified, + but not both. + + When determining the pixel density in the projection, exactly two + parameters from the set of `bounds_deg`, `res_deg` and `dims` must be + specified. + + Args: + coord (str): The coordinate frame name. + proj (str): The projection type. + center_deg (tuple): The (lon, lat) projection center in degrees. + bounds_deg (tuple): The (lon_min, lon_max, lat_min, lat_max) + values in degrees. + res_deg (tuple): The (lon, lat) resolution in degrees. + dims (tuple): The (lon, lat) projection size in pixels. + + Returns: + (WCS, shape): The instantiated WCS object and final shape. + + """ log = Logger.get() - log.verbose(f"PixelsWCS: set_wcs {proj}, {center}, {bounds}, {dims}, {res}") - if len(res) > 0: - res = np.array( - [ - res[0].to_value(u.degree), - res[1].to_value(u.degree), - ] - ) - if len(dims) > 0: - dims = np.array([self.dimensions[0], self.dimensions[1]]) - - if len(bounds) == 0: - # Using center, need both resolution and dimensions - if len(center) == 0: - # Cannot calculate yet - return - if len(res) == 0 or len(dims) == 0: - # Cannot calculate yet - return - pos = np.array( - [ - center[0].to_value(u.degree), - center[1].to_value(u.degree), - ] - ) - mid = pos + + # Compute projection center + if center_deg is not None: + # We are specifying the center. Bounds should not be set and we should + # have both resolution and dimensions + if bounds_deg is not None: + msg = f"PixelsWCS: only one of center and bounds should be set." + log.error(msg) + raise RuntimeError(msg) + if res_deg is None or dims is None: + msg = f"PixelsWCS: when center is set, both resolution and dimensions" + msg += f" are required." + log.error(msg) + raise RuntimeError(msg) + crval = np.array(center_deg, dtype=np.float64) else: - # Using bounds, exactly one of resolution or dimensions specified - if len(res) > 0 and len(dims) > 0: - # Cannot calculate yet - return - - # Max Longitude - lower_left_lon = bounds[0].to_value(u.degree) - # Min Latitude - lower_left_lat = bounds[2].to_value(u.degree) - # Min Longitude - upper_right_lon = bounds[1].to_value(u.degree) - # Max Latitude - upper_right_lat = bounds[3].to_value(u.degree) - - pos = np.array( - [[lower_left_lon, lower_left_lat], [upper_right_lon, upper_right_lat]] - ) - mid = np.mean(pos, axis=0) - - def _wcs_ref_res(w, p, r, d): - w.wcs.crpix = [1, 1] - if len(r) == 0: - w.wcs.cdelt = [1, 1] - corners = w.wcs_world2pix(p, 1) - w.wcs.cdelt *= (corners[1] - corners[0]) / d - else: - w.wcs.cdelt = r - if p.ndim == 2: - w.wcs.cdelt[p[1] < p[0]] *= -1 - if p.ndim == 1: - if len(dims) > 0: - off = w.wcs_world2pix(p[None], 0)[0] - w.wcs.crpix = np.array(d) / 2.0 + 0.5 - off + # Not using center, bounds is required + if bounds_deg is None: + msg = f"PixelsWCS: when center is not specified, bounds required." + log.error(msg) + raise RuntimeError(msg) + mid_lon = 0.5 * (bounds_deg[1] + bounds_deg[0]) + mid_lat = 0.5 * (bounds_deg[3] + bounds_deg[2]) + crval = np.array([mid_lon, mid_lat], dtype=np.float64) + # Either resolution or dimensions should be specified + if res_deg is not None: + # Using resolution + if dims is not None: + msg = f"PixelsWCS: when using bounds, only one of resolution or" + msg += f" dimensions must be specified." + log.error(msg) + raise RuntimeError(msg) else: - off = w.wcs_world2pix(p[0, None], 0)[0] + 0.5 - w.wcs.crpix -= off + # Using dimensions + if res_deg is not None: + msg = f"PixelsWCS: when using bounds, only one of resolution or" + msg += f" dimensions must be specified." + log.error(msg) + raise RuntimeError(msg) + + # Create the WCS object. + # CTYPE1 = Longitude + # CTYPE2 = Latitude + wcs = WCS(naxis=2) + + if coord == "AZEL": + # FIXME: The WCS standard does not define a keyword for + # horizontal coordinates. How should we deal with this? + # Also AZ is reversed from normal conventions- should we + # negate CDELT? + coordstr = ("RA--", "DEC-") + elif coord == "EQU": + coordstr = ("RA--", "DEC-") + elif coord == "GAL": + coordstr = ("GLON", "GLAT") + elif coord == "ECL": + coordstr = ("ELON", "ELAT") + else: + msg = f"Unsupported coordinate frame '{coord}'" + raise RuntimeError(msg) - self.wcs = WCS(naxis=2) if proj == "CAR": - self.wcs.wcs.ctype = ["RA---CAR", "DEC--CAR"] - self.wcs.wcs.crval = np.array([mid[0], 0]) + wcs.wcs.ctype = [f"{coordstr[0]}-CAR", f"{coordstr[1]}-CAR"] + wcs.wcs.crval = crval elif proj == "CEA": - self.wcs.wcs.ctype = ["RA---CEA", "DEC--CEA"] - self.wcs.wcs.crval = np.array([mid[0], 0]) - lam = np.cos(np.deg2rad(mid[1])) ** 2 - self.wcs.wcs.set_pv([(2, 1, lam)]) + wcs.wcs.ctype = [f"{coordstr[0]}-CEA", f"{coordstr[1]}-CEA"] + wcs.wcs.crval = crval + lam = np.cos(np.deg2rad(crval[1])) ** 2 + wcs.wcs.set_pv([(2, 1, lam)]) elif proj == "MER": - self.wcs.wcs.ctype = ["RA---MER", "DEC--MER"] - self.wcs.wcs.crval = np.array([mid[0], 0]) + wcs.wcs.ctype = [f"{coordstr[0]}-MER", f"{coordstr[1]}-MER"] + wcs.wcs.crval = crval elif proj == "ZEA": - self.wcs.wcs.ctype = ["RA---ZEA", "DEC--ZEA"] - self.wcs.wcs.crval = mid + wcs.wcs.ctype = [f"{coordstr[0]}-ZEA", f"{coordstr[1]}-ZEA"] + wcs.wcs.crval = crval elif proj == "TAN": - self.wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] - self.wcs.wcs.crval = mid + wcs.wcs.ctype = [f"{coordstr[0]}-TAN", f"{coordstr[1]}-TAN"] + wcs.wcs.crval = crval + elif proj == "SFL": + wcs.wcs.ctype = [f"{coordstr[0]}-SFL", f"{coordstr[1]}-SFL"] + wcs.wcs.crval = crval else: msg = f"Invalid WCS projection name '{proj}'" raise ValueError(msg) - _wcs_ref_res(self.wcs, pos, res, dims) - if len(dims) == 0: - # Compute from the bounding box corners - lower_left = self.wcs.wcs_world2pix(np.array([[pos[0, 0], pos[0, 1]]]), 0)[ - 0 - ] - upper_right = self.wcs.wcs_world2pix(np.array([[pos[1, 0], pos[1, 1]]]), 0)[ - 0 - ] - self.wcs_shape = tuple( - np.round(np.abs(upper_right - lower_left)).astype(int) - ) + # Compute resolution. Note that we negate the longitudinal + # coordinate so that the resulting projections match expectations + # for plotting, etc. + if center_deg is not None: + wcs.wcs.cdelt = np.array([-res_deg[0], res_deg[1]]) + else: + if res_deg is not None: + wcs.wcs.cdelt = np.array([-res_deg[0], res_deg[1]]) + else: + # Compute CDELT from the bounding box and image size. + wcs.wcs.cdelt = np.array( + [ + -(bounds_deg[1] - bounds_deg[0]) / dims[0], + (bounds_deg[3] - bounds_deg[2]) / dims[1], + ] + ) + + # Compute shape of the projection + if dims is not None: + wcs_shape = tuple(dims) else: - self.wcs_shape = tuple(dims) - log.verbose(f"PixelsWCS: wcs_shape = {self.wcs_shape}") + # Compute from the bounding box corners + lower_left = wcs.wcs_world2pix( + np.array([[bounds_deg[0], bounds_deg[2]]]), 0 + )[0] + upper_right = wcs.wcs_world2pix( + np.array([[bounds_deg[1], bounds_deg[3]]]), 0 + )[0] + wcs_shape = tuple(np.round(np.abs(upper_right - lower_left)).astype(int)) + + # Set the reference pixel to the center of the projection + off = wcs.wcs_world2pix(crval.reshape((1, 2)), 0)[0] + wcs.wcs.crpix = 0.5 * np.array(wcs_shape, dtype=np.float64) + 0.5 + off + + return wcs, wcs_shape - self.pix_ra = self.wcs_shape[0] - self.pix_dec = self.wcs_shape[1] - self._n_pix = self.pix_ra * self.pix_dec + def set_wcs(self): + if self._done_wcs: + return + + log = Logger.get() + msg = f"PixelsWCS: set_wcs coord={self.coord_frame}, " + msg += f"proj={self.projection}, center={self.center}, bounds={self.bounds}" + msg += f", dims={self.dimensions}, res={self.resolution}" + log.verbose(msg) + + center_deg = None + if len(self.center) > 0: + if self.center_offset is None: + center_deg = ( + self.center[0].to_value(u.degree), + self.center[1].to_value(u.degree), + ) + else: + center_deg = (0.0, 0.0) + bounds_deg = None + if len(self.bounds) > 0: + bounds_deg = tuple([x.to_value(u.degree) for x in self.bounds]) + res_deg = None + if len(self.resolution) > 0: + res_deg = tuple([x.to_value(u.degree) for x in self.resolution]) + if len(self.dimensions) > 0: + dims = tuple(self.dimensions) + else: + dims = None + + self.wcs, self.wcs_shape = self.create_wcs( + coord=self.coord_frame, + proj=self.projection, + center_deg=center_deg, + bounds_deg=bounds_deg, + res_deg=res_deg, + dims=dims, + ) + + self.pix_lon = self.wcs_shape[0] + self.pix_lat = self.wcs_shape[1] + self._n_pix = self.pix_lon * self.pix_lat self._n_pix_submap = self._n_pix // self.submaps if self._n_pix_submap * self.submaps < self._n_pix: self._n_pix_submap += 1 - self._local_submaps = None + self._local_submaps = np.zeros(self.submaps, dtype=np.uint8) + self._done_wcs = True return @function_timer @@ -333,6 +391,18 @@ def _exec(self, data, detectors=None, **kwargs): if not self.use_astropy: raise NotImplementedError("Only astropy conversion is currently supported") + if self.fits_header is not None: + # with open(self.fits_header, "rb") as f: + # header = af.Header.fromfile(f) + raise NotImplementedError( + "Initialization from a FITS header not yet finished" + ) + + if self.coord_frame == "AZEL": + is_azimuth = True + else: + is_azimuth = False + if self.auto_bounds and not self._done_auto: # Pass through the boresight pointing for every observation and build # the maximum extent of the detector field of view. @@ -348,6 +418,7 @@ def _exec(self, data, detectors=None, **kwargs): flags=self.detector_pointing.shared_flags, flag_mask=self.detector_pointing.shared_flag_mask, field_of_view=None, + is_azimuth=is_azimuth, center_offset=self.center_offset, ) lonmin = min(lonmin, lnmin) @@ -369,18 +440,17 @@ def _exec(self, data, detectors=None, **kwargs): latmin = all_lonlatmin[1] * u.radian lonmax = all_lonlatmax[0] * u.radian latmax = all_lonlatmax[1] * u.radian - new_bounds = ( + self.bounds = ( lonmin.to(u.degree), lonmax.to(u.degree), latmin.to(u.degree), latmax.to(u.degree), ) - log.verbose(f"PixelsWCS auto_bounds set to {new_bounds}") - self.bounds = new_bounds + log.verbose(f"PixelsWCS: auto_bounds set to {self.bounds}") self._done_auto = True - if self._local_submaps is None and self.create_dist is not None: - self._local_submaps = np.zeros(self.submaps, dtype=np.uint8) + # Compute the projection if needed + self.set_wcs() # Expand detector pointing quats_name = self.detector_pointing.quats @@ -474,41 +544,38 @@ def _exec(self, data, detectors=None, **kwargs): center_lonlat = None if self.center_offset is not None: - center_lonlat = ob.shared[self.center_offset].data + center_lonlat = np.radians(ob.shared[self.center_offset].data) # Process all detectors for det in dets: for vslice in view_slices: # Timestream of detector quaternions quats = ob.detdata[quats_name][det][vslice] - view_samples = len(quats) - theta, phi, _ = qa.to_iso_angles(quats) - to_deg = 180.0 / np.pi - theta *= to_deg - phi *= to_deg - shift = phi >= 360.0 - phi[shift] -= 360.0 - shift = phi < 0.0 - phi[shift] += 360.0 - world_in = np.column_stack([phi, 90.0 - theta]) + if center_lonlat is None: + center_offset = None + else: + center_offset = center_lonlat[vslice] + + rel_lon, rel_lat = center_offset_lonlat( + quats, + center_offset=center_offset, + degrees=True, + is_azimuth=is_azimuth, + ) - if center_lonlat is not None: - world_in[:, 0] -= center_lonlat[vslice, 0] - world_in[:, 1] -= center_lonlat[vslice, 1] + world_in = np.column_stack([rel_lon, rel_lat]) rdpix = self.wcs.wcs_world2pix(world_in, 0) - if flags is not None: - # Set bad pointing to pixel -1 - bad_pointing = flags[vslice] != 0 - rdpix[bad_pointing] = -1 rdpix = np.array(np.around(rdpix), dtype=np.int64) ob.detdata[self.pixels][det, vslice] = ( - rdpix[:, 0] * self.pix_dec + rdpix[:, 1] + rdpix[:, 0] * self.pix_lat + rdpix[:, 1] ) bad_pointing = ob.detdata[self.pixels][det, vslice] >= self._n_pix + if flags is not None: + bad_pointing = np.logical_or(bad_pointing, flags[vslice] != 0) (ob.detdata[self.pixels][det, vslice])[bad_pointing] = -1 if self.create_dist is not None: @@ -520,7 +587,6 @@ def _exec(self, data, detectors=None, **kwargs): def _finalize(self, data, **kwargs): if self.create_dist is not None: - submaps = None if self.single_precision: submaps = np.arange(self.submaps, dtype=np.int32)[ self._local_submaps == 1 @@ -539,6 +605,8 @@ def _finalize(self, data, **kwargs): # Store a copy of the WCS information in the distribution object data[self.create_dist].wcs = self.wcs.deepcopy() data[self.create_dist].wcs_shape = tuple(self.wcs_shape) + # Reset the local submaps + self._local_submaps[:] = 0 return def _requires(self): diff --git a/src/toast/pixels.py b/src/toast/pixels.py index 96ee9198e..3107dbfed 100644 --- a/src/toast/pixels.py +++ b/src/toast/pixels.py @@ -71,15 +71,14 @@ def __init__(self, n_pix=None, n_submap=1000, local_submaps=None, comm=None): self._local_submaps = local_submaps self._comm = comm - self._glob2loc = None self._n_local = 0 + self._glob2loc = AlignedI64.zeros(self._n_submap) + self._glob2loc[:] = -1 if self._local_submaps is not None and len(self._local_submaps) > 0: if np.max(self._local_submaps) > self._n_submap - 1: raise RuntimeError("local submap indices out of range") self._n_local = len(self._local_submaps) - self._glob2loc = AlignedI64.zeros(self._n_submap) - self._glob2loc[:] = -1 for ilocal_submap, iglobal_submap in enumerate(self._local_submaps): self._glob2loc[iglobal_submap] = ilocal_submap diff --git a/src/toast/pixels_io_wcs.py b/src/toast/pixels_io_wcs.py index 1d153a3cb..4f9d2bebd 100644 --- a/src/toast/pixels_io_wcs.py +++ b/src/toast/pixels_io_wcs.py @@ -40,8 +40,8 @@ def submap_to_image(dist, submap, sdata, image): """ imshape = image.shape n_value = imshape[0] - n_cols = imshape[1] - n_rows = imshape[2] + n_rows = imshape[1] + n_cols = imshape[2] # Global pixel range of this submap s_offset = submap * dist.n_pix_submap @@ -64,7 +64,7 @@ def submap_to_image(dist, submap, sdata, image): if pix_offset + n_copy > s_end: n_copy = s_end - pix_offset sbuf_offset = pix_offset + row_offset - s_offset - image[ival, col, row_offset : row_offset + n_copy] = sdata[ + image[ival, row_offset : row_offset + n_copy, col] = sdata[ sbuf_offset : sbuf_offset + n_copy, ival ] @@ -88,8 +88,8 @@ def image_to_submap(dist, image, submap, sdata, scale=1.0): """ imshape = image.shape n_value = imshape[0] - n_cols = imshape[1] - n_rows = imshape[2] + n_rows = imshape[1] + n_cols = imshape[2] # Global pixel range of this submap s_offset = submap * dist.n_pix_submap @@ -113,7 +113,7 @@ def image_to_submap(dist, image, submap, sdata, scale=1.0): n_copy = s_end - pix_offset sbuf_offset = pix_offset + row_offset - s_offset sdata[sbuf_offset : sbuf_offset + n_copy, ival] = ( - scale * image[ival, col, row_offset : row_offset + n_copy] + scale * image[ival, row_offset : row_offset + n_copy, col] ) @@ -153,10 +153,12 @@ def collect_wcs_submaps(pix, comm_bytes=10000000): allowners = np.zeros_like(owners) dist.comm.Allreduce(owners, allowners, op=MPI.MIN) - # Create an image array for the output + # Create an image array for the output. The FITS image data is column + # major, and so our numpy array has the order of axes swapped. image = None + image_shape = (pix.n_value, dist.wcs_shape[1], dist.wcs_shape[0]) if rank == 0: - image = np.zeros((pix.n_value,) + dist.wcs_shape, dtype=pix.dtype) + image = np.zeros(image_shape, dtype=pix.dtype) n_val_submap = dist.n_pix_submap * pix.n_value diff --git a/src/toast/pointing_utils.py b/src/toast/pointing_utils.py index 37997e3bb..b9d0c8706 100644 --- a/src/toast/pointing_utils.py +++ b/src/toast/pointing_utils.py @@ -1,18 +1,72 @@ -# Copyright (c) 2015-2023 by the parties listed in the AUTHORS file. +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. # Pointing utility functions used by templates and operators - import numpy as np from astropy import units as u from . import qarray as qa +from .instrument_coords import quat_to_xieta from .mpi import MPI from .timing import GlobalTimers, Timer, function_timer +def center_offset_lonlat( + quats, + center_offset=None, + degrees=False, + is_azimuth=False, +): + """Compute relative longitude / latitude from a dynamic center position. + + Args: + quats (array): Input pointing quaternions + center_offset (array): Center longitude, latitude in radians for each sample + degrees (bool): If True, return longitude / latitude values in degrees + is_azimuth (bool): If True, we are using azimuth and the sign of the + longitude values should be negated + + Returns: + (tuple): The (longitude, latitude) arrays + + """ + if center_offset is None: + lon_rad, lat_rad, _ = qa.to_lonlat_angles(quats) + else: + if len(quats.shape) == 2: + n_samp = quats.shape[0] + else: + n_samp = 1 + if center_offset.shape[0] != n_samp: + msg = f"center_offset dimensions {center_offset.shape}" + msg += f" not compatible with {n_samp} quaternion values" + raise ValueError(msg) + q_center = qa.from_lonlat_angles( + center_offset[:, 0], + center_offset[:, 1], + np.zeros_like(center_offset[:, 0]), + ) + q_final = qa.mult(qa.inv(q_center), quats) + lon_rad, lat_rad, _ = quat_to_xieta(q_final) + if is_azimuth: + lon_rad = 2 * np.pi - lon_rad + # Normalize range + shift = lon_rad >= 2 * np.pi + lon_rad[shift] -= 2 * np.pi + shift = lon_rad < 0 + lon_rad[shift] += 2 * np.pi + # Convert units + if degrees: + lon = np.degrees(lon_rad) + lat = np.degrees(lat_rad) + else: + lon = lon_rad + lat = lat_rad + return (lon, lat) + + @function_timer def scan_range_lonlat( obs, @@ -50,20 +104,36 @@ def scan_range_lonlat( fov = obs.telescope.focalplane.field_of_view fp_radius = 0.5 * fov.to_value(u.radian) + # The observation samples we are considering if samples is None: slc = slice(0, obs.n_local_samples, 1) else: slc = samples - # Get the flags if needed. - fdata = None + # Apply the flags to boresight pointing if needed. + bore_quats = np.array(obs.shared[boresight].data[slc, :]) if flags is not None: - fdata = np.array(obs.shared[flags][slc]) + fdata = np.array(obs.shared[flags].data[slc]) fdata &= flag_mask + bore_quats = bore_quats[fdata == 0, :] - # work in parallel + # The remaining good samples we have left + n_good = bore_quats.shape[0] + + # Check that the top of the focalplane is below the zenith + _, el_bore, _ = qa.to_lonlat_angles(bore_quats) + elmax_bore = np.amax(el_bore) + if elmax_bore + fp_radius > np.pi / 2: + msg = f"The scan range includes the zenith." + msg += f" Max boresight elevation is {np.degrees(elmax_bore)} deg" + msg += f" and focalplane radius is {np.degrees(fp_radius)} deg." + msg += " Scan range facility cannot handle this case." + raise RuntimeError(msg) + + # Work in parallel rank = obs.comm.group_rank ntask = obs.comm.group_size + rank_slice = slice(rank, n_good, ntask) # Create a fake focalplane of detectors in a circle around the boresight xaxis, yaxis, zaxis = np.eye(3) @@ -76,52 +146,29 @@ def scan_range_lonlat( detquat = qa.mult(phirot, thetarot) detquats.append(detquat) - # Get fake detector pointing - + # Get source center positions if needed center_lonlat = None if center_offset is not None: - center_lonlat = np.array(obs.shared[center_offset][slc, :]) + center_lonlat = np.array( + (obs.shared[center_offset].data[slc, :])[rank_slice, :] + ) + # center_offset is in degrees center_lonlat[:, :] *= np.pi / 180.0 - lon = [] - lat = [] - quats = obs.shared[boresight][slc, :][rank::ntask].copy() - rank_good = slice(None) - if fdata is not None: - rank_good = fdata[rank::ntask] == 0 - - # Check that the top of the focalplane is below the zenith - theta_bore, _, _ = qa.to_iso_angles(quats) - el_bore = np.pi / 2 - theta_bore[rank_good] - elmax_bore = np.amax(el_bore) - if elmax_bore + fp_radius > np.pi / 2: - msg = f"The scan range includes the zenith." - msg += f" Max boresight elevation is {np.degrees(elmax_bore)} deg" - msg += f" and focalplane radius is {np.degrees(fp_radius)} deg." - msg += " Scan range facility cannot handle this case." - raise RuntimeError(msg) - + # Compute pointing of fake detectors + lon = list() + lat = list() for idet, detquat in enumerate(detquats): - theta, phi, _ = qa.to_iso_angles(qa.mult(quats, detquat)) - if center_lonlat is None: - if is_azimuth: - lon.append(2 * np.pi - phi[rank_good]) - else: - lon.append(phi[rank_good]) - lat.append(np.pi / 2 - theta[rank_good]) - else: - if is_azimuth: - lon.append( - 2 * np.pi - - phi[rank_good] - - center_lonlat[rank::ntask, 0][rank_good] - ) - else: - lon.append(phi[rank_good] - center_lonlat[rank::ntask, 0][rank_good]) - lat.append( - (np.pi / 2 - theta[rank_good]) - - center_lonlat[rank::ntask, 1][rank_good] - ) + dquats = qa.mult(bore_quats, detquat) + det_lon, det_lat = center_offset_lonlat( + dquats, + center_offset=center_lonlat, + degrees=False, + is_azimuth=is_azimuth, + ) + lon.append(det_lon) + lat.append(det_lat) + lon = np.unwrap(np.hstack(lon)) lat = np.hstack(lat) @@ -131,13 +178,6 @@ def scan_range_lonlat( latmin = np.amin(lat) latmax = np.amax(lat) - if lonmin < -2 * np.pi: - lonmin += 2 * np.pi - lonmax += 2 * np.pi - elif lonmax > 2 * np.pi: - lonmin -= 2 * np.pi - lonmax -= 2 * np.pi - # Combine results if obs.comm.comm_group is not None: lonlatmin = np.zeros(2, dtype=np.float64) diff --git a/src/toast/tests/ops_pointing_wcs.py b/src/toast/tests/ops_pointing_wcs.py index 2ac0533f1..14527da0b 100644 --- a/src/toast/tests/ops_pointing_wcs.py +++ b/src/toast/tests/ops_pointing_wcs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022 by the parties listed in the AUTHORS file. +# Copyright (c) 2021-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. @@ -33,11 +33,17 @@ class PointingWCSTest(MPITestCase): def setUp(self): fixture_name = os.path.splitext(os.path.basename(__file__))[0] self.outdir = create_outdir(self.comm, fixture_name) + self.proj_dims = (1000, 500) # For debugging, change this to True self.write_extra = False - def check_hits(self, prefix, pixels): + def create_boresight_pointing(self, pixels): + # Given a fixed (not auto) wcs spec, simulate boresight pointing + if pixels.auto_bounds: + raise RuntimeError("Cannot use with auto bounds") + pixels.set_wcs() wcs = pixels.wcs + nlon, nlat = pixels.wcs_shape toastcomm = create_comm(self.comm) data = Data(toastcomm) @@ -46,108 +52,174 @@ def check_hits(self, prefix, pixels): sample_rate=1.0 * u.Hz, ) - # Make some fake boresight pointing - npix_ra = pixels.pix_ra - npix_dec = pixels.pix_dec px = list() - for ra in range(npix_ra): - px.extend( - np.column_stack( - [ - ra * np.ones(npix_dec), - np.arange(npix_dec), - ] - ).tolist() - ) - px = np.array(px, dtype=np.float64) - coord = wcs.wcs_pix2world(px, 0) - checkpx = wcs.wcs_world2pix(coord, 0) - coord *= np.pi / 180.0 + for plon in range(nlon): + for plat in range(nlat): + px.append([plon, plat]) + coord_deg = wcs.wcs_pix2world(np.array(px, dtype=np.float64), 0) + coord = np.radians(coord_deg) + phi = np.array(coord[:, 0], dtype=np.float64) half_pi = np.pi / 2 theta = np.array(half_pi - coord[:, 1], dtype=np.float64) bore = qa.from_iso_angles(theta, phi, np.zeros_like(theta)) - nsamp = npix_ra * npix_dec - data.obs.append(Observation(toastcomm, tele, n_samples=nsamp)) - data.obs[0].shared.create_column( - defaults.boresight_radec, (nsamp, 4), dtype=np.float64 - ) - data.obs[0].shared.create_column( - defaults.shared_flags, (nsamp,), dtype=np.uint8 - ) + nsamp = nlon * nlat + ob = Observation(toastcomm, tele, n_samples=nsamp) + ob.shared.create_column(defaults.boresight_radec, (nsamp, 4), dtype=np.float64) + ob.shared.create_column(defaults.shared_flags, (nsamp,), dtype=np.uint8) if toastcomm.group_rank == 0: - data.obs[0].shared[defaults.boresight_radec].set(bore) + ob.shared[defaults.boresight_radec].set(bore) else: - data.obs[0].shared[defaults.boresight_radec].set(None) + ob.shared[defaults.boresight_radec].set(None) + data.obs.append(ob) + return data + + def check_hits(self, prefix, pixels, data): + # Clear any existing pointing + for ob in data.obs: + if pixels.pixels in ob.detdata: + del ob.detdata[pixels.pixels] + if pixels.detector_pointing.quats in ob.detdata: + del ob.detdata[pixels.detector_pointing.quats] + # Pixel distribution + build_dist = ops.BuildPixelDistribution( + pixel_pointing=pixels, + ) + if build_dist.pixel_dist in data: + del data[build_dist.pixel_dist] + build_dist.apply(data) + + # Expand pointing pixels.apply(data) # Hitmap - build_hits = ops.BuildHitMap( - pixel_dist=pixels.create_dist, + pixel_dist=build_dist.pixel_dist, pixels=pixels.pixels, det_flags=None, ) + if build_hits.hits in data: + del data[build_hits.hits] build_hits.apply(data) if self.write_extra: outfile = os.path.join(self.outdir, f"{prefix}.fits") write_wcs_fits(data[build_hits.hits], outfile) + if data.comm.world_rank == 0: + plot_wcs_maps(hitfile=outfile) - if toastcomm.world_rank == 0: - set_matplotlib_backend() - - import matplotlib.pyplot as plt - - hdu = af.open(outfile)[0] - wcs = WCS(hdu.header) - - fig = plt.figure(figsize=(8, 8), dpi=100) - ax = fig.add_subplot(projection=wcs, slices=("x", "y", 0)) - # plt.imshow(hdu.data, vmin=-2.e-5, vmax=2.e-4, origin='lower') - im = ax.imshow( - np.transpose(hdu.data[0, :, :]), vmin=0, vmax=4, cmap="jet" - ) - ax.grid(color="white", ls="solid") - ax.set_xlabel("Longitude") - ax.set_ylabel("Latitude") - plt.colorbar(im, orientation="vertical") - fig.savefig(os.path.join(self.outdir, f"{prefix}.pdf"), format="pdf") - + flat_hits = data[build_hits.hits].data.flatten() + nonzero = flat_hits != 0 + hits_per_pixel = data.comm.ngroups * len(data.obs[0].all_detectors) + expected = np.zeros_like(flat_hits) + expected[nonzero] = hits_per_pixel np.testing.assert_array_equal( - data[build_hits.hits].data, - data.comm.ngroups - * len(data.obs[0].all_detectors) - * np.ones_like(data[build_hits.hits].data), + flat_hits, + expected, ) - close_data(data) + + def test_wcs(self): + return + # Test basic creation of WCS projections and plotting + res_deg = (0.01, 0.01) + dims = self.proj_dims + center_deg = (130.0, -30.0) + bounds_deg = (120.0, 140.0, -35.0, -25.0) + for proj in ["CAR", "TAN", "CEA", "MER", "ZEA", "SFL"]: + wcs, wcs_shape = ops.PixelsWCS.create_wcs( + coord="EQU", + proj=proj, + center_deg=None, + bounds_deg=bounds_deg, + res_deg=res_deg, + dims=None, + ) + if self.comm is None or self.comm.rank == 0: + pixdata = np.ones((1, wcs_shape[1], wcs_shape[0]), dtype=np.float32) + header = wcs.to_header() + hdu = af.PrimaryHDU(data=pixdata, header=header) + outfile = os.path.join(self.outdir, f"test_wcs_{proj}_bounds.fits") + hdu.writeto(outfile) + plot_wcs_maps(hitfile=outfile) + for proj in ["CAR", "TAN", "CEA", "MER", "ZEA", "SFL"]: + wcs, wcs_shape = ops.PixelsWCS.create_wcs( + coord="EQU", + proj=proj, + center_deg=center_deg, + bounds_deg=None, + res_deg=res_deg, + dims=dims, + ) + if self.comm is None or self.comm.rank == 0: + pixdata = np.ones((1, wcs_shape[1], wcs_shape[0]), dtype=np.float32) + header = wcs.to_header() + hdu = af.PrimaryHDU(data=pixdata, header=header) + outfile = os.path.join(self.outdir, f"test_wcs_{proj}_center.fits") + hdu.writeto(outfile) + plot_wcs_maps(hitfile=outfile) def test_projections(self): + return centers = list() - for lon in [130.0, 180.0, 230.0]: - for lat in [-40.0, 0.0, 40.0]: + for lon in [130.0, 180.0]: + for lat in [-40.0, 0.0]: centers.append((lon * u.degree, lat * u.degree)) detpointing_radec = ops.PointingDetectorSimple( boresight=defaults.boresight_radec ) - for proj in ["CAR", "TAN", "CEA", "MER", "ZEA"]: + # For each projection and center, run once and then change resolution + # test autoscaling. + + for proj in ["CAR", "TAN", "CEA", "MER", "ZEA", "SFL"]: for center in centers: pixels = ops.PixelsWCS( projection=proj, detector_pointing=detpointing_radec, - create_dist="dist", use_astropy=True, - center=center, - dimensions=(710, 350), - resolution=(0.1 * u.degree, 0.1 * u.degree), ) + # Verify that we can change the projection traits in various ways. + # First use non-auto_bounds to create one boresight pointing per + # pixel. + pixels.center = center + pixels.bounds = () + pixels.resolution = (0.02 * u.degree, 0.02 * u.degree) + pixels.dimensions = self.proj_dims + + data = self.create_boresight_pointing(pixels) self.check_hits( - f"hits_{proj}_{center[0].value}_{center[1].value}", pixels + f"hits_{proj}_0.02_{center[0].value}_{center[1].value}", + pixels, + data, ) + + self.assertFalse(pixels.auto_bounds) + self.assertTrue(pixels.center == center) + self.assertTrue(pixels.resolution == (0.02 * u.degree, 0.02 * u.degree)) + self.assertTrue(pixels.dimensions == self.proj_dims) + self.assertTrue(pixels.dimensions[0] == pixels.wcs_shape[0]) + self.assertTrue(pixels.dimensions[1] == pixels.wcs_shape[1]) + + # Note, increasing resolution will leave some pixels un-hit, but + # the check_hits() helper function will only check pixels with >0 hits + pixels.resolution = (0.01 * u.degree, 0.01 * u.degree) + pixels.center = () + pixels.dimensions = () + pixels.auto_bounds = True + + self.check_hits( + f"hits_{proj}_0.01_{center[0].value}_{center[1].value}_auto", + pixels, + data, + ) + + self.assertTrue(pixels.resolution == (0.01 * u.degree, 0.01 * u.degree)) + self.assertTrue(pixels.auto_bounds) + + close_data(data) if self.comm is not None: self.comm.barrier() @@ -159,7 +231,6 @@ def test_mapmaking(self): # Test several projections resolution = 0.1 * u.degree - # for proj in ["CAR", "TAN", "CEA", "MER", "ZEA"]: for proj in ["CAR"]: # Create fake observing of a small patch data = create_ground_data(self.comm) @@ -180,7 +251,8 @@ def test_mapmaking(self): pixels = ops.PixelsWCS( detector_pointing=detpointing_radec, projection=proj, - resolution=(0.5 * u.degree, 0.5 * u.degree), + resolution=(0.02 * u.degree, 0.02 * u.degree), + dimensions=(), auto_bounds=True, use_astropy=True, ) @@ -235,6 +307,7 @@ def test_mapmaking(self): pixel_pointing=pixels, stokes_weights=weights, noise_model=default_model.noise_model, + full_pointing=True, ) # Set up template matrix with just an offset template. @@ -255,37 +328,29 @@ def test_mapmaking(self): # Map maker mapper = ops.MapMaker( - name=f"test_{proj}", + name=f"mapmaking_{proj}", det_data=defaults.det_data, binning=binner, template_matrix=tmatrix, solve_rcond_threshold=1.0e-2, map_rcond_threshold=1.0e-2, - write_hits=False, - write_map=False, + write_hits=True, + write_map=True, write_cov=False, write_rcond=False, output_dir=self.outdir, - keep_solver_products=True, - keep_final_products=True, + keep_solver_products=False, + keep_final_products=False, ) if data.comm.comm_world is not None: data.comm.comm_world.barrier() mapper.apply(data) - if self.write_extra: - # Write outputs manually - for prod in ["hits", "map"]: - outfile = os.path.join(self.outdir, f"mapmaking_{proj}_{prod}.fits") - write_wcs_fits(data[f"{mapper.name}_{prod}"], outfile) - - if rank == 0: - outfile = os.path.join(self.outdir, f"mapmaking_{proj}_hits.fits") - plot_wcs_maps(hitfile=outfile) - - outfile = os.path.join(self.outdir, f"mapmaking_{proj}_map.fits") - plot_wcs_maps(mapfile=outfile) + if rank == 0: + hitfile = os.path.join(self.outdir, f"mapmaking_{proj}_hits.fits") + mapfile = os.path.join(self.outdir, f"mapmaking_{proj}_map.fits") + plot_wcs_maps(hitfile=hitfile, mapfile=mapfile) close_data(data) @@ -297,19 +362,45 @@ def fake_source(self, mission_start, ra_start, dec_start, times, deg_per_hour=1. incr = (times - t_start) * deg_sec return first_ra + incr, first_dec + incr + def fake_drone(self, mission_start, az_target, el_target, times, deg_amplitude=1.0): + # Just simulate moving in a circle around the target location + t_start = float(times[0]) + t_off = t_start - mission_start + n_samp = len(times) + ang = (2 * np.pi / n_samp) * np.arange(n_samp) + az = az_target + deg_amplitude * np.cos(ang) + el = el_target + deg_amplitude * np.sin(ang) + return az, el + def create_source_data( - self, data, proj, res, signal_name, deg_per_hour=1.0, dbg_dir=None + self, + data, + proj, + res, + signal_name, + deg_per_hour=1.0, + azel=False, + deg_amplitude=1.0, + dbg_dir=None, ): - detpointing = ops.PointingDetectorSimple( - boresight=defaults.boresight_radec, - quats="temp_quats", - ) - detpointing.apply(data) + if azel: + detpointing = ops.PointingDetectorSimple( + boresight=defaults.boresight_azel, + quats="temp_quats", + ) + detpointing.apply(data) + else: + detpointing = ops.PointingDetectorSimple( + boresight=defaults.boresight_radec, + quats="temp_quats", + ) + detpointing.apply(data) # Normal autoscaled projection pixels = ops.PixelsWCS( projection=proj, resolution=(res, res), + dimensions=(), detector_pointing=detpointing, pixels="temp_pix", use_astropy=True, @@ -356,32 +447,42 @@ def create_source_data( ) scanner.apply(data) - # Use this overall projection window to determine our source - # movement. The source starts at the center of the projection. - px = np.array( - [ - [ - int(0.6 * pixels.pix_ra), - int(0.2 * pixels.pix_dec), - ], - ], - dtype=np.float64, - ) + if azel: + # Simulating a drone near the center + px = np.array( + [[int(0.5 * pixels.pix_lat), int(0.5 * pixels.pix_lon)]], + dtype=np.float64, + ) + else: + # Use this overall projection window to determine our source + # movement. The source starts at the center of the projection. + px = np.array( + [[int(0.6 * pixels.pix_lat), int(0.2 * pixels.pix_lon)]], + dtype=np.float64, + ) source_start = pixels.wcs.wcs_pix2world(px, 0) # Create the fake ephemeris data and accumulate to signal. for ob in data.obs: n_samp = ob.n_local_samples times = np.array(ob.shared[defaults.times].data) - - source_ra, source_dec = self.fake_source( - data.obs[0].shared["times"][0], - source_start[0][0], - source_start[0][1], - times, - deg_per_hour=deg_per_hour, - ) - source_coord = np.column_stack([source_ra, source_dec]) + if azel: + source_lon, source_lat = self.fake_drone( + data.obs[0].shared["times"][0], + source_start[0][0], + source_start[0][1], + times, + deg_amplitude=deg_amplitude, + ) + else: + source_lon, source_lat = self.fake_source( + data.obs[0].shared["times"][0], + source_start[0][0], + source_start[0][1], + times, + deg_per_hour=deg_per_hour, + ) + source_coord = np.column_stack([source_lon, source_lat]) # Create a shared data object with the fake source location ob.shared.create_column("source", (n_samp, 2), dtype=np.float64) @@ -405,7 +506,7 @@ def create_source_data( sdist_arc = sdist * 180.0 * 60.0 / np.pi seen = sdist_arc < 10 seen_samp = np.arange(len(sdist), dtype=np.int32) - amp = 50.0 * coeff * np.exp(pre * np.square(sdist_arc)) + amp = 10.0 * coeff * np.exp(pre * np.square(sdist_arc)) ob.detdata[signal_name][det, :] += amp[:] if dbg_dir is not None and ob.comm.group_rank == 0: @@ -459,10 +560,9 @@ def create_source_data( ) mapper.apply(data) if data.comm.world_rank == 0: - outfile = os.path.join(self.outdir, f"source_{proj}_notrack_hits.fits") - plot_wcs_maps(hitfile=outfile) - outfile = os.path.join(self.outdir, f"source_{proj}_notrack_map.fits") - plot_wcs_maps(mapfile=outfile) + hitfile = os.path.join(self.outdir, f"source_{proj}_notrack_hits.fits") + mapfile = os.path.join(self.outdir, f"source_{proj}_notrack_map.fits") + plot_wcs_maps(hitfile=hitfile, mapfile=mapfile) # Cleanup our temp objects ops.Delete( @@ -476,15 +576,15 @@ def test_source_map(self): rank = self.comm.rank # Test several projections - resolution = 0.5 * u.degree + resolution = 0.02 * u.degree - for proj in ["CAR", "TAN"]: + for proj in ["TAN"]: # Create fake observing of a small patch data = create_ground_data(self.comm, pixel_per_process=10) # Create source motion and simulated detector data. dbgdir = None - if proj == "CAR" and self.write_extra: + if self.write_extra: dbgdir = self.outdir self.create_source_data( data, proj, resolution, defaults.det_data, dbg_dir=dbgdir @@ -506,6 +606,7 @@ def test_source_map(self): pixels = ops.PixelsWCS( projection=proj, resolution=(resolution, resolution), + dimensions=(), center_offset="source", detector_pointing=detpointing_radec, use_astropy=True, @@ -527,6 +628,7 @@ def test_source_map(self): pixel_pointing=pixels, stokes_weights=weights, noise_model=default_model.noise_model, + full_pointing=True, ) # Set up template matrix with just an offset template. @@ -552,7 +654,6 @@ def test_source_map(self): det_data=defaults.det_data, solve_rcond_threshold=1.0e-2, map_rcond_threshold=1.0e-2, - iter_max=10, binning=binner, template_matrix=tmatrix, output_dir=self.outdir, @@ -563,11 +664,122 @@ def test_source_map(self): mapper.apply(data) if rank == 0: - outfile = os.path.join(self.outdir, f"source_{proj}_hits.fits") - plot_wcs_maps(hitfile=outfile) - outfile = os.path.join(self.outdir, f"source_{proj}_map.fits") - plot_wcs_maps(mapfile=outfile) - outfile = os.path.join(self.outdir, f"source_{proj}_binmap.fits") - plot_wcs_maps(mapfile=outfile) + hitfile = os.path.join(self.outdir, f"source_{proj}_hits.fits") + mapfile = os.path.join(self.outdir, f"source_{proj}_map.fits") + binmapfile = os.path.join(self.outdir, f"source_{proj}_binmap.fits") + plot_wcs_maps(hitfile=hitfile, mapfile=mapfile) + plot_wcs_maps(hitfile=hitfile, mapfile=binmapfile) + + close_data(data) + + def test_drone_map(self): + rank = 0 + if self.comm is not None: + rank = self.comm.rank + + # Test several projections + resolution = 0.02 * u.degree + + for proj in ["SFL"]: + # Create fake observing of a small patch + data = create_ground_data(self.comm, pixel_per_process=10) + + # We are going to hack the boresight pointing so that the RA/DEC simulated + # pointing is treated as Az/El. This means that the scan pattern will not + # be realistic, but at least should cover the source + for obs in data.obs: + if obs.comm_col_rank == 0: + obs.shared["boresight_azel"].data[:, :] = obs.shared[ + "boresight_radec" + ].data[:, :] + + # Create source motion and simulated detector data. + dbgdir = None + if self.write_extra: + dbgdir = self.outdir + self.create_source_data( + data, proj, resolution, defaults.det_data, azel=True, dbg_dir=dbgdir + ) + + # Simple detector pointing + detpointing_azel = ops.PointingDetectorSimple( + boresight=defaults.boresight_azel, + ) + + # Stokes weights + weights = ops.StokesWeights( + mode="IQU", + hwp_angle=defaults.hwp_angle, + detector_pointing=detpointing_azel, + ) + + # Source-centered pointing + pixels = ops.PixelsWCS( + coord_frame="AZEL", + projection=proj, + resolution=(resolution, resolution), + dimensions=(), + center_offset="source", + detector_pointing=detpointing_azel, + use_astropy=True, + auto_bounds=True, + ) + + pix_dist = ops.BuildPixelDistribution( + pixel_dist="pixel_dist", + pixel_pointing=pixels, + ) + pix_dist.apply(data) + + default_model = ops.DefaultNoiseModel(noise_model="noise_model") + default_model.apply(data) + + # Set up binning operator for solving + binner = ops.BinMap( + pixel_dist="pixel_dist", + pixel_pointing=pixels, + stokes_weights=weights, + noise_model=default_model.noise_model, + full_pointing=True, + ) + + # Set up template matrix with just an offset template. + + # Use 1/10 of an observation as the baseline length. Make it not evenly + # divisible in order to test handling of the final amplitude. + ob_time = ( + data.obs[0].shared[defaults.times][-1] + - data.obs[0].shared[defaults.times][0] + ) + step_seconds = float(int(ob_time / 10.0)) + tmpl = templates.Offset( + times=defaults.times, + det_flags=None, + noise_model=default_model.noise_model, + step_time=step_seconds * u.second, + ) + tmatrix = ops.TemplateMatrix(templates=[tmpl]) + + # Map maker + mapper = ops.MapMaker( + name=f"drone_{proj}", + det_data=defaults.det_data, + solve_rcond_threshold=1.0e-2, + map_rcond_threshold=1.0e-2, + binning=binner, + template_matrix=tmatrix, + output_dir=self.outdir, + write_hits=True, + write_map=True, + write_binmap=True, + ) + mapper.apply(data) + + if rank == 0: + hitfile = os.path.join(self.outdir, f"drone_{proj}_hits.fits") + mapfile = os.path.join(self.outdir, f"drone_{proj}_map.fits") + binmapfile = os.path.join(self.outdir, f"drone_{proj}_binmap.fits") + plot_wcs_maps(hitfile=hitfile, mapfile=mapfile) + plot_wcs_maps(hitfile=hitfile, mapfile=binmapfile) close_data(data) diff --git a/src/toast/vis.py b/src/toast/vis.py index 7c8d9b7db..911affed4 100644 --- a/src/toast/vis.py +++ b/src/toast/vis.py @@ -125,6 +125,7 @@ def plot_wcs_maps( xmax=None, ymin=None, ymax=None, + cmap="viridis", ): """Plot WCS projected output maps. @@ -142,22 +143,35 @@ def plot_wcs_maps( xmax (float): Fraction (0.0-1.0) of the maximum X view. ymin (float): Fraction (0.0-1.0) of the minimum Y view. ymin (float): Fraction (0.0-1.0) of the maximum Y view. + cmap (str): The color map name to use. """ + import matplotlib as mpl import matplotlib.pyplot as plt - figsize = (12, 12) figdpi = 100 + current_cmap = mpl.cm.get_cmap(cmap) + current_cmap.set_bad(color="gray") + def plot_single(wcs, hdata, hindx, vmin, vmax, out): + xwcs = wcs.pixel_shape[0] + ywcs = wcs.pixel_shape[1] + fig_x = xwcs / figdpi + fig_y = ywcs / figdpi + figsize = (fig_x, fig_y) fig = plt.figure(figsize=figsize, dpi=figdpi) ax = fig.add_subplot(projection=wcs, slices=("x", "y", hindx)) im = ax.imshow( - np.transpose(hdata.data[hindx, :, :]), cmap="jet", vmin=vmin, vmax=vmax + hdata[hindx, :, :], + cmap=current_cmap, + vmin=vmin, + vmax=vmax, + interpolation="nearest", ) ax.grid(color="white", ls="solid") - ax.set_xlabel("Longitude") - ax.set_ylabel("Latitude") + ax.set_xlabel(f"{wcs.wcs.ctype[0]}") + ax.set_ylabel(f"{wcs.wcs.ctype[1]}") if xmin is not None and xmax is not None: ax.set_xlim(xmin, xmax) if ymin is not None and ymax is not None: @@ -181,25 +195,30 @@ def sym_range(hdata): ext = max(np.absolute(minval), np.absolute(maxval)) return -ext, ext - def sub_mono(hitdata, mdata): - if hitdata is None: + def flag_unhit(hitmask, mdata): + if hitmask is None: return - goodpix = np.logical_and((hitdata > 0), (mdata != 0)) + for mindx in range(mdata.shape[0]): + mdata[mindx, hitmask] = np.nan + + def sub_mono(hitmask, mdata): + if hitmask is None: + goodpix = mdata != 0 + else: + goodpix = np.logical_and(hitmask, (mdata != 0)) mono = np.mean(mdata[goodpix]) - print(f"Monopole = {mono}") mdata[goodpix] -= mono - mdata[np.logical_not(goodpix)] = 0 - hitdata = None + hitmask = None if hitfile is not None: hdulist = af.open(hitfile) hdu = hdulist[0] - hitdata = np.array(hdu.data[0, :, :]) + hitmask = np.array(hdu.data[0, :, :] == 0) wcs = WCS(hdu.header) - maxhits = np.amax(hdu.data[0, :, :]) + maxhits = 0.5 * np.amax(hdu.data[0, :, :]) if max_hits is not None: maxhits = max_hits - plot_single(wcs, hdu, 0, 0, maxhits, f"{hitfile}.pdf") + plot_single(wcs, hdu.data, 0, 0, maxhits, f"{hitfile}.pdf") del hdu hdulist.close() @@ -207,44 +226,48 @@ def sub_mono(hitdata, mdata): hdulist = af.open(mapfile) hdu = hdulist[0] wcs = WCS(hdu.header) + mapdata = np.array(hdu.data) + del hdu if truth is not None: thdulist = af.open(truth) thdu = thdulist[0] - sub_mono(hitdata, hdu.data[0, :, :]) - mmin, mmax = sym_range(hdu.data[0, :, :]) + flag_unhit(hitmask, mapdata) + + sub_mono(hitmask, mapdata[0]) + mmin, mmax = sym_range(mapdata[0, :, :]) if range_I is not None: mmin, mmax = range_I - plot_single(wcs, hdu, 0, mmin, mmax, f"{mapfile}_I.pdf") + plot_single(wcs, mapdata, 0, mmin, mmax, f"{mapfile}_I.pdf") if truth is not None: tmin, tmax = sym_range(thdu.data[0, :, :]) - hdu.data[0, :, :] -= thdu.data[0, :, :] - plot_single(wcs, hdu, 0, tmin, tmax, f"{mapfile}_resid_I.pdf") + mapdata[0, :, :] -= thdu.data[0, :, :] + plot_single(wcs, mapdata, 0, tmin, tmax, f"{mapfile}_resid_I.pdf") - if hdu.data.shape[0] > 1: - mmin, mmax = sym_range(hdu.data[1, :, :]) + if mapdata.shape[0] > 1: + mmin, mmax = sym_range(mapdata[1, :, :]) if range_Q is not None: mmin, mmax = range_Q - plot_single(wcs, hdu, 1, mmin, mmax, f"{mapfile}_Q.pdf") + plot_single(wcs, mapdata, 1, mmin, mmax, f"{mapfile}_Q.pdf") if truth is not None: tmin, tmax = sym_range(thdu.data[1, :, :]) - hdu.data[1, :, :] -= thdu.data[1, :, :] - plot_single(wcs, hdu, 1, tmin, tmax, f"{mapfile}_resid_Q.pdf") + mapdata[1, :, :] -= thdu.data[1, :, :] + plot_single(wcs, mapdata, 1, tmin, tmax, f"{mapfile}_resid_Q.pdf") - mmin, mmax = sym_range(hdu.data[2, :, :]) + mmin, mmax = sym_range(mapdata[2, :, :]) if range_U is not None: mmin, mmax = range_U - plot_single(wcs, hdu, 2, mmin, mmax, f"{mapfile}_U.pdf") + plot_single(wcs, mapdata, 2, mmin, mmax, f"{mapfile}_U.pdf") if truth is not None: tmin, tmax = sym_range(thdu.data[2, :, :]) - hdu.data[2, :, :] -= thdu.data[2, :, :] - plot_single(wcs, hdu, 2, tmin, tmax, f"{mapfile}_resid_U.pdf") + mapdata[2, :, :] -= thdu.data[2, :, :] + plot_single(wcs, mapdata, 2, tmin, tmax, f"{mapfile}_resid_U.pdf") if truth is not None: del thdu thdulist.close() - del hdu + hdulist.close() @@ -463,7 +486,6 @@ def sym_range(data): gnomres *= 60 if gnomview: gnomrot = (mlon, mlat, 0.0) - print(f"gnomres = {gnomres} arcmin, gnomrot = {gnomrot}", flush=True) plot_single( hitdata, 0, From 65e56798407817446cb8af074945c126022fe3c1 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Fri, 10 May 2024 23:44:48 -0700 Subject: [PATCH 2/5] Fix other unit tests --- src/toast/tests/ops_scan_wcs.py | 2 ++ src/toast/tests/template_hwpss.py | 2 ++ src/toast/tests/template_periodic.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/toast/tests/ops_scan_wcs.py b/src/toast/tests/ops_scan_wcs.py index 6b054224b..99da89f9b 100644 --- a/src/toast/tests/ops_scan_wcs.py +++ b/src/toast/tests/ops_scan_wcs.py @@ -41,6 +41,7 @@ def test_wcs_fits(self): pixels = ops.PixelsWCS( projection="CAR", resolution=(0.05 * u.degree, 0.05 * u.degree), + dimensions=(), auto_bounds=True, detector_pointing=detpointing_radec, create_dist="pixel_dist", @@ -105,6 +106,7 @@ def test_wcs_mask(self): pixels = ops.PixelsWCS( projection="CAR", resolution=(0.05 * u.degree, 0.05 * u.degree), + dimensions=(), auto_bounds=True, detector_pointing=detpointing_radec, create_dist="pixel_dist", diff --git a/src/toast/tests/template_hwpss.py b/src/toast/tests/template_hwpss.py index 23e9107c9..4307608fe 100644 --- a/src/toast/tests/template_hwpss.py +++ b/src/toast/tests/template_hwpss.py @@ -453,6 +453,7 @@ def create_ground_sim(self, outdir, width, sky_proj, sky_res): resolution=(sky_res, sky_res), detector_pointing=detpointing, pixels="temp_pix", + dimensions=(), use_astropy=True, auto_bounds=True, ) @@ -782,6 +783,7 @@ def test_ground_hwp_narrow(self): detector_pointing=detpointing_radec, projection=proj, resolution=(res, res), + dimensions=(), auto_bounds=True, use_astropy=True, ) diff --git a/src/toast/tests/template_periodic.py b/src/toast/tests/template_periodic.py index acad6fa13..746bead81 100644 --- a/src/toast/tests/template_periodic.py +++ b/src/toast/tests/template_periodic.py @@ -295,6 +295,7 @@ def create_ground_sim(self, outdir, width, sky_proj, sky_res): pixels = ops.PixelsWCS( projection=sky_proj, resolution=(sky_res, sky_res), + dimensions=(), detector_pointing=detpointing, pixels="temp_pix", use_astropy=True, @@ -618,6 +619,7 @@ def test_ground_hwp_narrow(self): detector_pointing=detpointing_radec, projection=proj, resolution=(res, res), + dimensions=(), auto_bounds=True, use_astropy=True, ) From 8d272d4e7ffdf60554e6896d6ea7bfec720f9664 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Mon, 13 May 2024 06:44:07 -0700 Subject: [PATCH 3/5] For local Az/El coordinate frame, use TLON/TLAT in the CTYPE. Thanks to @gabrielecoppi for the suggestion. --- src/toast/ops/pixels_wcs.py | 8 +++----- src/toast/scripts/toast_plot_wcs.py | 18 ++++++++++++++++++ src/toast/tests/ops_pointing_wcs.py | 4 ++-- src/toast/vis.py | 4 ++++ 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/toast/ops/pixels_wcs.py b/src/toast/ops/pixels_wcs.py index 97a68d62f..82d7091e7 100644 --- a/src/toast/ops/pixels_wcs.py +++ b/src/toast/ops/pixels_wcs.py @@ -256,11 +256,9 @@ def create_wcs( wcs = WCS(naxis=2) if coord == "AZEL": - # FIXME: The WCS standard does not define a keyword for - # horizontal coordinates. How should we deal with this? - # Also AZ is reversed from normal conventions- should we - # negate CDELT? - coordstr = ("RA--", "DEC-") + # For local Azimuth and Elevation coordinate frame, we + # use the generic longitude and latitude string. + coordstr = ("TLON", "TLAT") elif coord == "EQU": coordstr = ("RA--", "DEC-") elif coord == "GAL": diff --git a/src/toast/scripts/toast_plot_wcs.py b/src/toast/scripts/toast_plot_wcs.py index afdbb7e20..6a0c1c560 100644 --- a/src/toast/scripts/toast_plot_wcs.py +++ b/src/toast/scripts/toast_plot_wcs.py @@ -124,6 +124,22 @@ def main(): help="Maximum Y viewport fraction (0.0 - 1.0)", ) + parser.add_argument( + "--cmap", + required=False, + type=str, + default="viridis", + help="The colormap name (e.g. 'inferno')", + ) + + parser.add_argument( + "--azimuth", + required=False, + default=False, + action="store_true", + help="Data is Azimuth / Elevation, so invert the X-axis", + ) + args = parser.parse_args() range_I = None @@ -148,6 +164,8 @@ def main(): xmax=args.Xmax, ymin=args.Ymin, ymax=args.Ymax, + is_azimuth=args.azimuth, + cmap=args.cmap, ) diff --git a/src/toast/tests/ops_pointing_wcs.py b/src/toast/tests/ops_pointing_wcs.py index 14527da0b..dbfa88d0d 100644 --- a/src/toast/tests/ops_pointing_wcs.py +++ b/src/toast/tests/ops_pointing_wcs.py @@ -779,7 +779,7 @@ def test_drone_map(self): hitfile = os.path.join(self.outdir, f"drone_{proj}_hits.fits") mapfile = os.path.join(self.outdir, f"drone_{proj}_map.fits") binmapfile = os.path.join(self.outdir, f"drone_{proj}_binmap.fits") - plot_wcs_maps(hitfile=hitfile, mapfile=mapfile) - plot_wcs_maps(hitfile=hitfile, mapfile=binmapfile) + plot_wcs_maps(hitfile=hitfile, mapfile=mapfile, is_azimuth=True) + plot_wcs_maps(hitfile=hitfile, mapfile=binmapfile, is_azimuth=True) close_data(data) diff --git a/src/toast/vis.py b/src/toast/vis.py index 911affed4..2508dfd00 100644 --- a/src/toast/vis.py +++ b/src/toast/vis.py @@ -125,6 +125,7 @@ def plot_wcs_maps( xmax=None, ymin=None, ymax=None, + is_azimuth=False, cmap="viridis", ): """Plot WCS projected output maps. @@ -143,6 +144,7 @@ def plot_wcs_maps( xmax (float): Fraction (0.0-1.0) of the maximum X view. ymin (float): Fraction (0.0-1.0) of the minimum Y view. ymin (float): Fraction (0.0-1.0) of the maximum Y view. + is_azimuth (bool): If True, swap direction of longitude axis. cmap (str): The color map name to use. """ @@ -169,6 +171,8 @@ def plot_single(wcs, hdata, hindx, vmin, vmax, out): vmax=vmax, interpolation="nearest", ) + if is_azimuth: + ax.invert_xaxis() ax.grid(color="white", ls="solid") ax.set_xlabel(f"{wcs.wcs.ctype[0]}") ax.set_ylabel(f"{wcs.wcs.ctype[1]}") From 9fab690faceabb561ee6d3fa3096bf14ac7e1914 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Mon, 13 May 2024 12:10:35 -0700 Subject: [PATCH 4/5] PixelDistribution global2local is now always created --- src/toast/pixels.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/toast/pixels.py b/src/toast/pixels.py index 3107dbfed..384ada990 100644 --- a/src/toast/pixels.py +++ b/src/toast/pixels.py @@ -191,12 +191,7 @@ def global_pixel_to_submap(self, gl): msg = "Global pixel indices exceed the maximum for the pixelization" log.error(msg) raise RuntimeError(msg) - if self._glob2loc is None: - msg = "PixelDistribution: no local submaps defined" - log.error(msg) - raise RuntimeError(msg) - else: - return libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc) + libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc) @function_timer def global_pixel_to_local(self, gl): From 8ad6e46e3ea883b426a9443a59bfce476fbbf554 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Mon, 13 May 2024 12:54:13 -0700 Subject: [PATCH 5/5] Fix typo --- src/toast/pixels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/toast/pixels.py b/src/toast/pixels.py index 384ada990..5224a01f8 100644 --- a/src/toast/pixels.py +++ b/src/toast/pixels.py @@ -191,7 +191,7 @@ def global_pixel_to_submap(self, gl): msg = "Global pixel indices exceed the maximum for the pixelization" log.error(msg) raise RuntimeError(msg) - libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc) + return libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc) @function_timer def global_pixel_to_local(self, gl):