Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3607: descriptive names, inputs, and outputs for wcs transforms #8524

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions jwst/assign_mtwcs/moving_target_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import logging
from copy import deepcopy
import numpy as np
from astropy.modeling.models import Shift, Identity
from astropy.modeling.models import Shift
from gwcs import WCS
from gwcs import coordinate_frames as cf
from jwst.assign_wcs.util import wl_identity

from stdatamodels.jwst import datamodels

Expand Down Expand Up @@ -99,7 +100,7 @@ def add_mt_frame(wcs, ra_average, dec_average, mt_ra, mt_dec):
if isinstance(mt, cf.CelestialFrame):
transform_to_mt = Shift(rdel) & Shift(ddel)
elif isinstance(mt, cf.CompositeFrame):
transform_to_mt = Shift(rdel) & Shift(ddel) & Identity(1)
transform_to_mt = Shift(rdel) & Shift(ddel) & wl_identity()
else:
raise ValueError("Unrecognized coordinate frame.")

Expand Down
40 changes: 29 additions & 11 deletions jwst/assign_wcs/miri.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import numpy as np
from astropy.modeling import models
from astropy.modeling.models import Identity
from astropy import coordinates as coord
from astropy import units as u
from astropy.io import fits
Expand All @@ -18,7 +19,8 @@
from . import pointing
from .util import (not_implemented_mode, subarray_transform,
velocity_correction, transform_bbox_from_shape,
bounding_box_from_subarray)
bounding_box_from_subarray,
wl_identity)


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -151,6 +153,9 @@
distortion.bounding_box = transform_bbox_from_shape(input_model.data.shape)
else:
distortion.bounding_box = bbox
distortion.inputs = ("x", "y")
distortion.outputs = ("v2", "v3")
distortion.name = "imaging_distortion"

Check warning on line 158 in jwst/assign_wcs/miri.py

View check run for this annotation

Codecov / codecov/patch

jwst/assign_wcs/miri.py#L156-L158

Added lines #L156 - L158 were not covered by tests
return distortion


Expand Down Expand Up @@ -195,14 +200,14 @@
# Create the transforms
dettotel = lrs_distortion(input_model, reference_files)
v2v3tosky = pointing.v23tosky(input_model)
teltosky = v2v3tosky & models.Identity(1)
teltosky = v2v3tosky & wl_identity()

# Compute differential velocity aberration (DVA) correction:
va_corr = pointing.dva_corr_model(
va_scale=input_model.meta.velocity_aberration.scale_factor,
v2_ref=input_model.meta.wcsinfo.v2_ref,
v3_ref=input_model.meta.wcsinfo.v3_ref
) & models.Identity(1)
) & wl_identity()

# Put the transforms together into a single pipeline
pipeline = [(detector, dettotel),
Expand Down Expand Up @@ -357,6 +362,10 @@
# Bounding box is the subarray bounding box, because we're assuming subarray coordinates passed in
dettotel.bounding_box = bb_sub[::-1]

dettotel.name = "lrs_distortion"
dettotel.inputs = ("x_direct", "y_direct")
dettotel.outputs = ("v2", "v3", "lam")

return dettotel

def ifu(input_model, reference_files):
Expand Down Expand Up @@ -399,9 +408,9 @@
va_scale=input_model.meta.velocity_aberration.scale_factor,
v2_ref=input_model.meta.wcsinfo.v2_ref,
v3_ref=input_model.meta.wcsinfo.v3_ref
) & models.Identity(1)
) & wl_identity()

tel2sky = pointing.v23tosky(input_model) & models.Identity(1)
tel2sky = pointing.v23tosky(input_model) & wl_identity()

# Put the transforms together into a single transform
det2abl.bounding_box = transform_bbox_from_shape(input_model.data.shape)
Expand Down Expand Up @@ -473,20 +482,25 @@
with WavelengthrangeModel(reference_files['wavelengthrange']) as f:
wr = dict(zip(f.waverange_selector, f.wavelengthrange))

det_labels = ('x_direct', 'y_direct')
abl_labels = ('alpha', 'beta', 'lam')
ch_dict = {}
for c in channel:
cb = c + band
mapper = MIRI_AB2Slice(bzero[cb], bdel[cb], c)
lm = selector.LabelMapper(inputs=('alpha', 'beta', 'lam'),
lm = selector.LabelMapper(inputs=abl_labels,
mapper=mapper, inputs_mapping=models.Mapping((1,), n_inputs=3))
ch_dict[tuple(wr[cb])] = lm

alpha_beta_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), ch_dict,
alpha_beta_mapper = selector.LabelMapperRange(abl_labels, ch_dict,
models.Mapping((2,)))
label_mapper.inverse = alpha_beta_mapper

det2alpha_beta = selector.RegionsSelector(('x', 'y'), ('alpha', 'beta', 'lam'),
det2alpha_beta = selector.RegionsSelector(det_labels, abl_labels,
label_mapper=label_mapper, selector=transforms)
det2alpha_beta.name = "detector_to_alpha_beta"
det2alpha_beta.inputs = det_labels
det2alpha_beta.outputs = abl_labels
return det2alpha_beta


Expand Down Expand Up @@ -536,13 +550,17 @@
v23c = v23_spatial & ident1
sel[ch] = v23c

wave_range_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), dict_mapper,
abl_labels = ('alpha', 'beta', 'lam')
v2v3_labels = ('v2', 'v3', 'lam')
wave_range_mapper = selector.LabelMapperRange(abl_labels, dict_mapper,
inputs_mapping=models.Mapping([2, ]))
wave_range_mapper.inverse = wave_range_mapper.copy()
abl2v2v3l = selector.RegionsSelector(('alpha', 'beta', 'lam'), ('v2', 'v3', 'lam'),
abl2v2v3l = selector.RegionsSelector(abl_labels, v2v3_labels,
label_mapper=wave_range_mapper,
selector=sel)

abl2v2v3l.name = "alpha_beta_to_v2v3"
abl2v2v3l.inputs = abl_labels
abl2v2v3l.outputs = v2v3_labels
return abl2v2v3l


Expand Down
35 changes: 25 additions & 10 deletions jwst/assign_wcs/nircam.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from . import pointing
from .util import (not_implemented_mode, subarray_transform, velocity_correction,
transform_bbox_from_shape, bounding_box_from_subarray)
transform_bbox_from_shape, bounding_box_from_subarray,
wl_order_identity)
from ..lib.reffile_utils import find_row


Expand Down Expand Up @@ -150,6 +151,10 @@ def imaging_distortion(input_model, reference_files):
transform.bounding_box = transform_bbox_from_shape(input_model.data.shape)
else:
transform.bounding_box = bbox

transform.name = "imaging_distortion"
transform.inputs = ('x_direct', 'y_direct')
transform.outputs = ('v2', 'v3')
return transform


Expand Down Expand Up @@ -259,40 +264,47 @@ def tsgrism(input_model, reference_files):
setdec = Const1D(input_model.meta.wcsinfo.dec_ref)
setdec.inverse = Const1D(input_model.meta.wcsinfo.dec_ref)

#wl_order_identity() = Identity(2)
#wl_order_identity().inputs = ('wavelength', 'order')
#wl_order_identity().outputs = ('wavelength', 'order')

# x, y, order in goes to transform to full array location and order
# get the shift to full frame coordinates
sub_trans = subarray_transform(input_model)
if sub_trans is not None:
sub2direct = (sub_trans & Identity(1) | Mapping((0, 1, 0, 1, 2)) |
(Identity(2) & xcenter & ycenter & Identity(1)) |
(wl_order_identity() & xcenter & ycenter & Identity(1)) |
det2det)
else:
sub2direct = (Mapping((0, 1, 0, 1, 2)) |
(Identity(2) & xcenter & ycenter & Identity(1)) |
(wl_order_identity() & xcenter & ycenter & Identity(1)) |
det2det)
sub2direct.name = "grism_to_image"
sub2direct.inputs = ('x_grism', 'y_grism', 'order')
sub2direct.outputs = ('x_direct', 'y_direct', 'wavelength', 'order')

# take us from full frame detector to v2v3
distortion = imaging_distortion(input_model, reference_files) & Identity(2)
distortion = imaging_distortion(input_model, reference_files) & wl_order_identity()

# Compute differential velocity aberration (DVA) correction:
va_corr = pointing.dva_corr_model(
va_scale=input_model.meta.velocity_aberration.scale_factor,
v2_ref=input_model.meta.wcsinfo.v2_ref,
v3_ref=input_model.meta.wcsinfo.v3_ref
) & Identity(2)
) & wl_order_identity()

# v2v3 to the sky
# remap the tel2sky inverse as well since we can feed it the values of
# crval1, crval2 which correspond to crpix1, crpix2. This leaves
# us with a calling structure:
# (x, y, order) <-> (wavelength, order)
tel2sky = pointing.v23tosky(input_model) & Identity(2)
tel2sky = pointing.v23tosky(input_model) & wl_order_identity()
t2skyinverse = tel2sky.inverse
newinverse = Mapping((0, 1, 0, 1)) | setra & setdec & Identity(2) | t2skyinverse
newinverse = Mapping((0, 1, 0, 1)) | setra & setdec & wl_order_identity() | t2skyinverse
tel2sky.inverse = newinverse

pipeline = [(frames['grism_detector'], sub2direct),
(frames['direct_image'], distortion),
(frames['detector'], distortion),
(frames['v2v3'], va_corr),
(frames['v2v3vacorr'], tel2sky),
(frames['world'], None)]
Expand Down Expand Up @@ -410,6 +422,9 @@ def wfss(input_model, reference_files):
velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys)
log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value))
det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1)
det2det.name = "grism_dispersion"
det2det.inputs = ('x_grism', 'y_grism', 'x0', 'y0', 'order')
det2det.outputs = ('x_direct', 'y_direct', 'wavelength', 'order')

# create the pipeline to construct a WCS object for the whole image
# which can translate ra,dec to image frame reference pixels
Expand All @@ -431,7 +446,7 @@ def wfss(input_model, reference_files):
world = image_pipeline.pop()[0]
world.name = 'sky'
for cframe, trans in image_pipeline:
trans = trans & (Identity(2))
trans = trans & (wl_order_identity())
name = cframe.name
cframe.name = name + 'spatial'
spatial_and_spectral = cf.CompositeFrame([cframe, spec],
Expand All @@ -456,7 +471,7 @@ def create_coord_frames():
spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,),
axes_names=('wavelength',))
frames = {'grism_detector': gdetector,
'direct_image': cf.CompositeFrame([detector, spec], name='direct_image'),
'detector': cf.CompositeFrame([detector, spec], name='detector'),
'v2v3': cf.CompositeFrame([v2v3_spatial, spec], name='v2v3'),
'v2v3vacorr': cf.CompositeFrame([v2v3vacorr_spatial, spec], name='v2v3vacorr'),
'world': cf.CompositeFrame([sky_frame, spec], name='world')
Expand Down
11 changes: 9 additions & 2 deletions jwst/assign_wcs/niriss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .util import (not_implemented_mode, subarray_transform,
velocity_correction, bounding_box_from_subarray,
transform_bbox_from_shape)
transform_bbox_from_shape, wl_order_identity)
from . import pointing
from ..lib.reffile_utils import find_row

Expand Down Expand Up @@ -312,6 +312,10 @@ def imaging_distortion(input_model, reference_files):
distortion.bounding_box = transform_bbox_from_shape(input_model.data.shape)
else:
distortion.bounding_box = bbox

distortion.inputs = ('x_direct', 'y_direct')
distortion.outputs = ('v2', 'v3')
distortion.name = "imaging_distortion"
return distortion


Expand Down Expand Up @@ -446,6 +450,9 @@ def wfss(input_model, reference_files):
velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys)
log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value))
det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1)
det2det.name = "grism_dispersion"
det2det.inputs = ('x_grism', 'y_grism', 'x0', 'y0', 'order')
det2det.outputs = ('x_direct', 'y_direct', 'wavelength', 'order')

# create the pipeline to construct a WCS object for the whole image
# which can translate ra,dec to image frame reference pixels
Expand All @@ -471,7 +478,7 @@ def wfss(input_model, reference_files):
world = image_pipeline.pop()[0]
world.name = 'sky'
for cframe, trans in image_pipeline:
trans = trans & (Identity(2))
trans = trans & (wl_order_identity())
name = cframe.name
cframe.name = name + 'spatial'
spatial_and_spectral = cf.CompositeFrame([cframe, spec], name=name)
Expand Down
Loading
Loading