Skip to content

Commit

Permalink
Refactor of WCS pixelization operator (#757)
Browse files Browse the repository at this point in the history
* Refactor of WCS pixelization operator

- Move the application of source centering in the projection to a
  separate helper function in `pointing_utils.py`.  Thanks to
  @gabrielecoppi for identifying this fix.  Optionally use this
  new function when computing the scan range for autoscaling.

- In `PixelsWCS`:

  - Add a new general class method that computes the WCS parameters.
  - Add support for SFL projection.
  - Allow projection traits to be changed in any order and only
    recompute the WCS if needed when exec() is called.
  - Default to a single submap, which is the most efficient choice
    for the common case of data distributed by detector and many
    observations co-incident on the sky.

- In the PixelsWCS unit tests:

  - Ensure projection and plotting works for every supported
    projection type with both fixed parameters and autoscaling.
  - Test mapmaking in both normal mode and with source-centered
    projections in RA/DEC and Az/El.

- In `plot_wcs_maps`:

  - Set the figure size based on the DPI and the
    actual size of the image in pixels.
  - Set the unhit pixels to gray.
  - Allow specifying the color map, and default to one of the
    perceptially uniform ones.

* Fix other unit tests

* For local Az/El coordinate frame, use TLON/TLAT in the CTYPE.  Thanks to @gabrielecoppi for the suggestion.

* PixelDistribution global2local is now always created

* Fix typo
  • Loading branch information
tskisner authored May 13, 2024
1 parent bf711e8 commit b1540cb
Show file tree
Hide file tree
Showing 10 changed files with 780 additions and 416 deletions.
452 changes: 259 additions & 193 deletions src/toast/ops/pixels_wcs.py

Large diffs are not rendered by default.

12 changes: 3 additions & 9 deletions src/toast/pixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ def __init__(self, n_pix=None, n_submap=1000, local_submaps=None, comm=None):
self._local_submaps = local_submaps
self._comm = comm

self._glob2loc = None
self._n_local = 0
self._glob2loc = AlignedI64.zeros(self._n_submap)
self._glob2loc[:] = -1

if self._local_submaps is not None and len(self._local_submaps) > 0:
if np.max(self._local_submaps) > self._n_submap - 1:
raise RuntimeError("local submap indices out of range")
self._n_local = len(self._local_submaps)
self._glob2loc = AlignedI64.zeros(self._n_submap)
self._glob2loc[:] = -1
for ilocal_submap, iglobal_submap in enumerate(self._local_submaps):
self._glob2loc[iglobal_submap] = ilocal_submap

Expand Down Expand Up @@ -192,12 +191,7 @@ def global_pixel_to_submap(self, gl):
msg = "Global pixel indices exceed the maximum for the pixelization"
log.error(msg)
raise RuntimeError(msg)
if self._glob2loc is None:
msg = "PixelDistribution: no local submaps defined"
log.error(msg)
raise RuntimeError(msg)
else:
return libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc)
return libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc)

@function_timer
def global_pixel_to_local(self, gl):
Expand Down
18 changes: 10 additions & 8 deletions src/toast/pixels_io_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def submap_to_image(dist, submap, sdata, image):
"""
imshape = image.shape
n_value = imshape[0]
n_cols = imshape[1]
n_rows = imshape[2]
n_rows = imshape[1]
n_cols = imshape[2]

# Global pixel range of this submap
s_offset = submap * dist.n_pix_submap
Expand All @@ -64,7 +64,7 @@ def submap_to_image(dist, submap, sdata, image):
if pix_offset + n_copy > s_end:
n_copy = s_end - pix_offset
sbuf_offset = pix_offset + row_offset - s_offset
image[ival, col, row_offset : row_offset + n_copy] = sdata[
image[ival, row_offset : row_offset + n_copy, col] = sdata[
sbuf_offset : sbuf_offset + n_copy, ival
]

Expand All @@ -88,8 +88,8 @@ def image_to_submap(dist, image, submap, sdata, scale=1.0):
"""
imshape = image.shape
n_value = imshape[0]
n_cols = imshape[1]
n_rows = imshape[2]
n_rows = imshape[1]
n_cols = imshape[2]

# Global pixel range of this submap
s_offset = submap * dist.n_pix_submap
Expand All @@ -113,7 +113,7 @@ def image_to_submap(dist, image, submap, sdata, scale=1.0):
n_copy = s_end - pix_offset
sbuf_offset = pix_offset + row_offset - s_offset
sdata[sbuf_offset : sbuf_offset + n_copy, ival] = (
scale * image[ival, col, row_offset : row_offset + n_copy]
scale * image[ival, row_offset : row_offset + n_copy, col]
)


Expand Down Expand Up @@ -153,10 +153,12 @@ def collect_wcs_submaps(pix, comm_bytes=10000000):
allowners = np.zeros_like(owners)
dist.comm.Allreduce(owners, allowners, op=MPI.MIN)

# Create an image array for the output
# Create an image array for the output. The FITS image data is column
# major, and so our numpy array has the order of axes swapped.
image = None
image_shape = (pix.n_value, dist.wcs_shape[1], dist.wcs_shape[0])
if rank == 0:
image = np.zeros((pix.n_value,) + dist.wcs_shape, dtype=pix.dtype)
image = np.zeros(image_shape, dtype=pix.dtype)

n_val_submap = dist.n_pix_submap * pix.n_value

Expand Down
148 changes: 94 additions & 54 deletions src/toast/pointing_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,72 @@
# Copyright (c) 2015-2023 by the parties listed in the AUTHORS file.
# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

# Pointing utility functions used by templates and operators


import numpy as np
from astropy import units as u

from . import qarray as qa
from .instrument_coords import quat_to_xieta
from .mpi import MPI
from .timing import GlobalTimers, Timer, function_timer


def center_offset_lonlat(
quats,
center_offset=None,
degrees=False,
is_azimuth=False,
):
"""Compute relative longitude / latitude from a dynamic center position.
Args:
quats (array): Input pointing quaternions
center_offset (array): Center longitude, latitude in radians for each sample
degrees (bool): If True, return longitude / latitude values in degrees
is_azimuth (bool): If True, we are using azimuth and the sign of the
longitude values should be negated
Returns:
(tuple): The (longitude, latitude) arrays
"""
if center_offset is None:
lon_rad, lat_rad, _ = qa.to_lonlat_angles(quats)
else:
if len(quats.shape) == 2:
n_samp = quats.shape[0]
else:
n_samp = 1
if center_offset.shape[0] != n_samp:
msg = f"center_offset dimensions {center_offset.shape}"
msg += f" not compatible with {n_samp} quaternion values"
raise ValueError(msg)
q_center = qa.from_lonlat_angles(
center_offset[:, 0],
center_offset[:, 1],
np.zeros_like(center_offset[:, 0]),
)
q_final = qa.mult(qa.inv(q_center), quats)
lon_rad, lat_rad, _ = quat_to_xieta(q_final)
if is_azimuth:
lon_rad = 2 * np.pi - lon_rad
# Normalize range
shift = lon_rad >= 2 * np.pi
lon_rad[shift] -= 2 * np.pi
shift = lon_rad < 0
lon_rad[shift] += 2 * np.pi
# Convert units
if degrees:
lon = np.degrees(lon_rad)
lat = np.degrees(lat_rad)
else:
lon = lon_rad
lat = lat_rad
return (lon, lat)


@function_timer
def scan_range_lonlat(
obs,
Expand Down Expand Up @@ -50,20 +104,36 @@ def scan_range_lonlat(
fov = obs.telescope.focalplane.field_of_view
fp_radius = 0.5 * fov.to_value(u.radian)

# The observation samples we are considering
if samples is None:
slc = slice(0, obs.n_local_samples, 1)
else:
slc = samples

# Get the flags if needed.
fdata = None
# Apply the flags to boresight pointing if needed.
bore_quats = np.array(obs.shared[boresight].data[slc, :])
if flags is not None:
fdata = np.array(obs.shared[flags][slc])
fdata = np.array(obs.shared[flags].data[slc])
fdata &= flag_mask
bore_quats = bore_quats[fdata == 0, :]

# work in parallel
# The remaining good samples we have left
n_good = bore_quats.shape[0]

# Check that the top of the focalplane is below the zenith
_, el_bore, _ = qa.to_lonlat_angles(bore_quats)
elmax_bore = np.amax(el_bore)
if elmax_bore + fp_radius > np.pi / 2:
msg = f"The scan range includes the zenith."
msg += f" Max boresight elevation is {np.degrees(elmax_bore)} deg"
msg += f" and focalplane radius is {np.degrees(fp_radius)} deg."
msg += " Scan range facility cannot handle this case."
raise RuntimeError(msg)

# Work in parallel
rank = obs.comm.group_rank
ntask = obs.comm.group_size
rank_slice = slice(rank, n_good, ntask)

# Create a fake focalplane of detectors in a circle around the boresight
xaxis, yaxis, zaxis = np.eye(3)
Expand All @@ -76,52 +146,29 @@ def scan_range_lonlat(
detquat = qa.mult(phirot, thetarot)
detquats.append(detquat)

# Get fake detector pointing

# Get source center positions if needed
center_lonlat = None
if center_offset is not None:
center_lonlat = np.array(obs.shared[center_offset][slc, :])
center_lonlat = np.array(
(obs.shared[center_offset].data[slc, :])[rank_slice, :]
)
# center_offset is in degrees
center_lonlat[:, :] *= np.pi / 180.0

lon = []
lat = []
quats = obs.shared[boresight][slc, :][rank::ntask].copy()
rank_good = slice(None)
if fdata is not None:
rank_good = fdata[rank::ntask] == 0

# Check that the top of the focalplane is below the zenith
theta_bore, _, _ = qa.to_iso_angles(quats)
el_bore = np.pi / 2 - theta_bore[rank_good]
elmax_bore = np.amax(el_bore)
if elmax_bore + fp_radius > np.pi / 2:
msg = f"The scan range includes the zenith."
msg += f" Max boresight elevation is {np.degrees(elmax_bore)} deg"
msg += f" and focalplane radius is {np.degrees(fp_radius)} deg."
msg += " Scan range facility cannot handle this case."
raise RuntimeError(msg)

# Compute pointing of fake detectors
lon = list()
lat = list()
for idet, detquat in enumerate(detquats):
theta, phi, _ = qa.to_iso_angles(qa.mult(quats, detquat))
if center_lonlat is None:
if is_azimuth:
lon.append(2 * np.pi - phi[rank_good])
else:
lon.append(phi[rank_good])
lat.append(np.pi / 2 - theta[rank_good])
else:
if is_azimuth:
lon.append(
2 * np.pi
- phi[rank_good]
- center_lonlat[rank::ntask, 0][rank_good]
)
else:
lon.append(phi[rank_good] - center_lonlat[rank::ntask, 0][rank_good])
lat.append(
(np.pi / 2 - theta[rank_good])
- center_lonlat[rank::ntask, 1][rank_good]
)
dquats = qa.mult(bore_quats, detquat)
det_lon, det_lat = center_offset_lonlat(
dquats,
center_offset=center_lonlat,
degrees=False,
is_azimuth=is_azimuth,
)
lon.append(det_lon)
lat.append(det_lat)

lon = np.unwrap(np.hstack(lon))
lat = np.hstack(lat)

Expand All @@ -131,13 +178,6 @@ def scan_range_lonlat(
latmin = np.amin(lat)
latmax = np.amax(lat)

if lonmin < -2 * np.pi:
lonmin += 2 * np.pi
lonmax += 2 * np.pi
elif lonmax > 2 * np.pi:
lonmin -= 2 * np.pi
lonmax -= 2 * np.pi

# Combine results
if obs.comm.comm_group is not None:
lonlatmin = np.zeros(2, dtype=np.float64)
Expand Down
18 changes: 18 additions & 0 deletions src/toast/scripts/toast_plot_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def main():
help="Maximum Y viewport fraction (0.0 - 1.0)",
)

parser.add_argument(
"--cmap",
required=False,
type=str,
default="viridis",
help="The colormap name (e.g. 'inferno')",
)

parser.add_argument(
"--azimuth",
required=False,
default=False,
action="store_true",
help="Data is Azimuth / Elevation, so invert the X-axis",
)

args = parser.parse_args()

range_I = None
Expand All @@ -148,6 +164,8 @@ def main():
xmax=args.Xmax,
ymin=args.Ymin,
ymax=args.Ymax,
is_azimuth=args.azimuth,
cmap=args.cmap,
)


Expand Down
Loading

0 comments on commit b1540cb

Please sign in to comment.