Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of WCS pixelization operator #757

Merged
merged 5 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
452 changes: 259 additions & 193 deletions src/toast/ops/pixels_wcs.py

Large diffs are not rendered by default.

12 changes: 3 additions & 9 deletions src/toast/pixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ def __init__(self, n_pix=None, n_submap=1000, local_submaps=None, comm=None):
self._local_submaps = local_submaps
self._comm = comm

self._glob2loc = None
self._n_local = 0
self._glob2loc = AlignedI64.zeros(self._n_submap)
self._glob2loc[:] = -1

if self._local_submaps is not None and len(self._local_submaps) > 0:
if np.max(self._local_submaps) > self._n_submap - 1:
raise RuntimeError("local submap indices out of range")
self._n_local = len(self._local_submaps)
self._glob2loc = AlignedI64.zeros(self._n_submap)
self._glob2loc[:] = -1
for ilocal_submap, iglobal_submap in enumerate(self._local_submaps):
self._glob2loc[iglobal_submap] = ilocal_submap

Expand Down Expand Up @@ -192,12 +191,7 @@ def global_pixel_to_submap(self, gl):
msg = "Global pixel indices exceed the maximum for the pixelization"
log.error(msg)
raise RuntimeError(msg)
if self._glob2loc is None:
msg = "PixelDistribution: no local submaps defined"
log.error(msg)
raise RuntimeError(msg)
else:
return libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc)
return libtoast_global_to_local(gl, self._n_pix_submap, self._glob2loc)

@function_timer
def global_pixel_to_local(self, gl):
Expand Down
18 changes: 10 additions & 8 deletions src/toast/pixels_io_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def submap_to_image(dist, submap, sdata, image):
"""
imshape = image.shape
n_value = imshape[0]
n_cols = imshape[1]
n_rows = imshape[2]
n_rows = imshape[1]
n_cols = imshape[2]

# Global pixel range of this submap
s_offset = submap * dist.n_pix_submap
Expand All @@ -64,7 +64,7 @@ def submap_to_image(dist, submap, sdata, image):
if pix_offset + n_copy > s_end:
n_copy = s_end - pix_offset
sbuf_offset = pix_offset + row_offset - s_offset
image[ival, col, row_offset : row_offset + n_copy] = sdata[
image[ival, row_offset : row_offset + n_copy, col] = sdata[
sbuf_offset : sbuf_offset + n_copy, ival
]

Expand All @@ -88,8 +88,8 @@ def image_to_submap(dist, image, submap, sdata, scale=1.0):
"""
imshape = image.shape
n_value = imshape[0]
n_cols = imshape[1]
n_rows = imshape[2]
n_rows = imshape[1]
n_cols = imshape[2]

# Global pixel range of this submap
s_offset = submap * dist.n_pix_submap
Expand All @@ -113,7 +113,7 @@ def image_to_submap(dist, image, submap, sdata, scale=1.0):
n_copy = s_end - pix_offset
sbuf_offset = pix_offset + row_offset - s_offset
sdata[sbuf_offset : sbuf_offset + n_copy, ival] = (
scale * image[ival, col, row_offset : row_offset + n_copy]
scale * image[ival, row_offset : row_offset + n_copy, col]
)


Expand Down Expand Up @@ -153,10 +153,12 @@ def collect_wcs_submaps(pix, comm_bytes=10000000):
allowners = np.zeros_like(owners)
dist.comm.Allreduce(owners, allowners, op=MPI.MIN)

# Create an image array for the output
# Create an image array for the output. The FITS image data is column
# major, and so our numpy array has the order of axes swapped.
image = None
image_shape = (pix.n_value, dist.wcs_shape[1], dist.wcs_shape[0])
if rank == 0:
image = np.zeros((pix.n_value,) + dist.wcs_shape, dtype=pix.dtype)
image = np.zeros(image_shape, dtype=pix.dtype)

n_val_submap = dist.n_pix_submap * pix.n_value

Expand Down
148 changes: 94 additions & 54 deletions src/toast/pointing_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,72 @@
# Copyright (c) 2015-2023 by the parties listed in the AUTHORS file.
# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

# Pointing utility functions used by templates and operators


import numpy as np
from astropy import units as u

from . import qarray as qa
from .instrument_coords import quat_to_xieta
from .mpi import MPI
from .timing import GlobalTimers, Timer, function_timer


def center_offset_lonlat(
quats,
center_offset=None,
degrees=False,
is_azimuth=False,
):
"""Compute relative longitude / latitude from a dynamic center position.

Args:
quats (array): Input pointing quaternions
center_offset (array): Center longitude, latitude in radians for each sample
degrees (bool): If True, return longitude / latitude values in degrees
is_azimuth (bool): If True, we are using azimuth and the sign of the
longitude values should be negated

Returns:
(tuple): The (longitude, latitude) arrays

"""
if center_offset is None:
lon_rad, lat_rad, _ = qa.to_lonlat_angles(quats)
else:
if len(quats.shape) == 2:
n_samp = quats.shape[0]
else:
n_samp = 1
if center_offset.shape[0] != n_samp:
msg = f"center_offset dimensions {center_offset.shape}"
msg += f" not compatible with {n_samp} quaternion values"
raise ValueError(msg)
q_center = qa.from_lonlat_angles(
center_offset[:, 0],
center_offset[:, 1],
np.zeros_like(center_offset[:, 0]),
)
q_final = qa.mult(qa.inv(q_center), quats)
lon_rad, lat_rad, _ = quat_to_xieta(q_final)
if is_azimuth:
lon_rad = 2 * np.pi - lon_rad
# Normalize range
shift = lon_rad >= 2 * np.pi
lon_rad[shift] -= 2 * np.pi
shift = lon_rad < 0
lon_rad[shift] += 2 * np.pi
# Convert units
if degrees:
lon = np.degrees(lon_rad)
lat = np.degrees(lat_rad)
else:
lon = lon_rad
lat = lat_rad
return (lon, lat)


@function_timer
def scan_range_lonlat(
obs,
Expand Down Expand Up @@ -50,20 +104,36 @@ def scan_range_lonlat(
fov = obs.telescope.focalplane.field_of_view
fp_radius = 0.5 * fov.to_value(u.radian)

# The observation samples we are considering
if samples is None:
slc = slice(0, obs.n_local_samples, 1)
else:
slc = samples

# Get the flags if needed.
fdata = None
# Apply the flags to boresight pointing if needed.
bore_quats = np.array(obs.shared[boresight].data[slc, :])
if flags is not None:
fdata = np.array(obs.shared[flags][slc])
fdata = np.array(obs.shared[flags].data[slc])
fdata &= flag_mask
bore_quats = bore_quats[fdata == 0, :]

# work in parallel
# The remaining good samples we have left
n_good = bore_quats.shape[0]

# Check that the top of the focalplane is below the zenith
_, el_bore, _ = qa.to_lonlat_angles(bore_quats)
elmax_bore = np.amax(el_bore)
if elmax_bore + fp_radius > np.pi / 2:
msg = f"The scan range includes the zenith."
msg += f" Max boresight elevation is {np.degrees(elmax_bore)} deg"
msg += f" and focalplane radius is {np.degrees(fp_radius)} deg."
msg += " Scan range facility cannot handle this case."
raise RuntimeError(msg)

# Work in parallel
rank = obs.comm.group_rank
ntask = obs.comm.group_size
rank_slice = slice(rank, n_good, ntask)

# Create a fake focalplane of detectors in a circle around the boresight
xaxis, yaxis, zaxis = np.eye(3)
Expand All @@ -76,52 +146,29 @@ def scan_range_lonlat(
detquat = qa.mult(phirot, thetarot)
detquats.append(detquat)

# Get fake detector pointing

# Get source center positions if needed
center_lonlat = None
if center_offset is not None:
center_lonlat = np.array(obs.shared[center_offset][slc, :])
center_lonlat = np.array(
(obs.shared[center_offset].data[slc, :])[rank_slice, :]
)
# center_offset is in degrees
center_lonlat[:, :] *= np.pi / 180.0

lon = []
lat = []
quats = obs.shared[boresight][slc, :][rank::ntask].copy()
rank_good = slice(None)
if fdata is not None:
rank_good = fdata[rank::ntask] == 0

# Check that the top of the focalplane is below the zenith
theta_bore, _, _ = qa.to_iso_angles(quats)
el_bore = np.pi / 2 - theta_bore[rank_good]
elmax_bore = np.amax(el_bore)
if elmax_bore + fp_radius > np.pi / 2:
msg = f"The scan range includes the zenith."
msg += f" Max boresight elevation is {np.degrees(elmax_bore)} deg"
msg += f" and focalplane radius is {np.degrees(fp_radius)} deg."
msg += " Scan range facility cannot handle this case."
raise RuntimeError(msg)

# Compute pointing of fake detectors
lon = list()
lat = list()
for idet, detquat in enumerate(detquats):
theta, phi, _ = qa.to_iso_angles(qa.mult(quats, detquat))
if center_lonlat is None:
if is_azimuth:
lon.append(2 * np.pi - phi[rank_good])
else:
lon.append(phi[rank_good])
lat.append(np.pi / 2 - theta[rank_good])
else:
if is_azimuth:
lon.append(
2 * np.pi
- phi[rank_good]
- center_lonlat[rank::ntask, 0][rank_good]
)
else:
lon.append(phi[rank_good] - center_lonlat[rank::ntask, 0][rank_good])
lat.append(
(np.pi / 2 - theta[rank_good])
- center_lonlat[rank::ntask, 1][rank_good]
)
dquats = qa.mult(bore_quats, detquat)
det_lon, det_lat = center_offset_lonlat(
dquats,
center_offset=center_lonlat,
degrees=False,
is_azimuth=is_azimuth,
)
lon.append(det_lon)
lat.append(det_lat)

lon = np.unwrap(np.hstack(lon))
lat = np.hstack(lat)

Expand All @@ -131,13 +178,6 @@ def scan_range_lonlat(
latmin = np.amin(lat)
latmax = np.amax(lat)

if lonmin < -2 * np.pi:
lonmin += 2 * np.pi
lonmax += 2 * np.pi
elif lonmax > 2 * np.pi:
lonmin -= 2 * np.pi
lonmax -= 2 * np.pi

# Combine results
if obs.comm.comm_group is not None:
lonlatmin = np.zeros(2, dtype=np.float64)
Expand Down
18 changes: 18 additions & 0 deletions src/toast/scripts/toast_plot_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def main():
help="Maximum Y viewport fraction (0.0 - 1.0)",
)

parser.add_argument(
"--cmap",
required=False,
type=str,
default="viridis",
help="The colormap name (e.g. 'inferno')",
)

parser.add_argument(
"--azimuth",
required=False,
default=False,
action="store_true",
help="Data is Azimuth / Elevation, so invert the X-axis",
)

args = parser.parse_args()

range_I = None
Expand All @@ -148,6 +164,8 @@ def main():
xmax=args.Xmax,
ymin=args.Ymin,
ymax=args.Ymax,
is_azimuth=args.azimuth,
cmap=args.cmap,
)


Expand Down
Loading