From 2db8861922ad4d7fdcad4884efb3df7f92ce8c15 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 6 May 2024 09:40:23 -0400 Subject: [PATCH 1/4] intermediate progress, fixes except nirspec --- jwst/assign_wcs/miri.py | 40 +++++++--- jwst/assign_wcs/nircam.py | 31 +++++--- jwst/assign_wcs/niriss.py | 8 +- jwst/assign_wcs/nirspec.py | 105 +++++++++++++++++++------- jwst/assign_wcs/pointing.py | 14 +++- jwst/assign_wcs/tests/test_miri.py | 47 ++++++++++++ jwst/assign_wcs/tests/test_nircam.py | 47 +++++++++--- jwst/assign_wcs/tests/test_niriss.py | 25 ++++++ jwst/assign_wcs/tests/test_nirspec.py | 86 +++++++++++++++++++++ jwst/assign_wcs/util.py | 22 ++++++ 10 files changed, 362 insertions(+), 63 deletions(-) diff --git a/jwst/assign_wcs/miri.py b/jwst/assign_wcs/miri.py index f8f4372687..28dc3a5d4e 100644 --- a/jwst/assign_wcs/miri.py +++ b/jwst/assign_wcs/miri.py @@ -2,6 +2,7 @@ import logging import numpy as np from astropy.modeling import models +from astropy.modeling.models import Identity from astropy import coordinates as coord from astropy import units as u from astropy.io import fits @@ -18,7 +19,8 @@ from . import pointing from .util import (not_implemented_mode, subarray_transform, velocity_correction, transform_bbox_from_shape, - bounding_box_from_subarray) + bounding_box_from_subarray, + wl_identity) log = logging.getLogger(__name__) @@ -151,6 +153,9 @@ def imaging_distortion(input_model, reference_files): distortion.bounding_box = transform_bbox_from_shape(input_model.data.shape) else: distortion.bounding_box = bbox + distortion.inputs = ("x", "y") + distortion.outputs = ("v2", "v3") + distortion.name = "imaging_distortion" return distortion @@ -195,14 +200,14 @@ def lrs(input_model, reference_files): # Create the transforms dettotel = lrs_distortion(input_model, reference_files) v2v3tosky = pointing.v23tosky(input_model) - teltosky = v2v3tosky & models.Identity(1) + teltosky = v2v3tosky & wl_identity() # Compute differential velocity aberration (DVA) correction: va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, v3_ref=input_model.meta.wcsinfo.v3_ref - ) & models.Identity(1) + ) & wl_identity() # Put the transforms together into a single pipeline pipeline = [(detector, dettotel), @@ -357,6 +362,10 @@ def lrs_distortion(input_model, reference_files): # Bounding box is the subarray bounding box, because we're assuming subarray coordinates passed in dettotel.bounding_box = bb_sub[::-1] + dettotel.name = "lrs_distortion" + dettotel.inputs = ("x", "y") + dettotel.outputs = ("v2", "v3", "lam") + return dettotel def ifu(input_model, reference_files): @@ -399,9 +408,9 @@ def ifu(input_model, reference_files): va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, v3_ref=input_model.meta.wcsinfo.v3_ref - ) & models.Identity(1) + ) & wl_identity() - tel2sky = pointing.v23tosky(input_model) & models.Identity(1) + tel2sky = pointing.v23tosky(input_model) & wl_identity() # Put the transforms together into a single transform det2abl.bounding_box = transform_bbox_from_shape(input_model.data.shape) @@ -473,20 +482,25 @@ def detector_to_abl(input_model, reference_files): with WavelengthrangeModel(reference_files['wavelengthrange']) as f: wr = dict(zip(f.waverange_selector, f.wavelengthrange)) + det_labels = ('x', 'y') + abl_labels = ('alpha', 'beta', 'lam') ch_dict = {} for c in channel: cb = c + band mapper = MIRI_AB2Slice(bzero[cb], bdel[cb], c) - lm = selector.LabelMapper(inputs=('alpha', 'beta', 'lam'), + lm = selector.LabelMapper(inputs=abl_labels, mapper=mapper, inputs_mapping=models.Mapping((1,), n_inputs=3)) ch_dict[tuple(wr[cb])] = lm - alpha_beta_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), ch_dict, + alpha_beta_mapper = selector.LabelMapperRange(abl_labels, ch_dict, models.Mapping((2,))) label_mapper.inverse = alpha_beta_mapper - det2alpha_beta = selector.RegionsSelector(('x', 'y'), ('alpha', 'beta', 'lam'), + det2alpha_beta = selector.RegionsSelector(det_labels, abl_labels, label_mapper=label_mapper, selector=transforms) + det2alpha_beta.name = "detector_to_alpha_beta" + det2alpha_beta.inputs = det_labels + det2alpha_beta.outputs = abl_labels return det2alpha_beta @@ -536,13 +550,17 @@ def abl_to_v2v3l(input_model, reference_files): v23c = v23_spatial & ident1 sel[ch] = v23c - wave_range_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), dict_mapper, + abl_labels = ('alpha', 'beta', 'lam') + v2v3_labels = ('v2', 'v3', 'lam') + wave_range_mapper = selector.LabelMapperRange(abl_labels, dict_mapper, inputs_mapping=models.Mapping([2, ])) wave_range_mapper.inverse = wave_range_mapper.copy() - abl2v2v3l = selector.RegionsSelector(('alpha', 'beta', 'lam'), ('v2', 'v3', 'lam'), + abl2v2v3l = selector.RegionsSelector(abl_labels, v2v3_labels, label_mapper=wave_range_mapper, selector=sel) - + abl2v2v3l.name = "alpha_beta_to_v2v3" + abl2v2v3l.inputs = abl_labels + abl2v2v3l.outputs = v2v3_labels return abl2v2v3l diff --git a/jwst/assign_wcs/nircam.py b/jwst/assign_wcs/nircam.py index 676a4b64f6..26dc770a35 100644 --- a/jwst/assign_wcs/nircam.py +++ b/jwst/assign_wcs/nircam.py @@ -14,7 +14,8 @@ from . import pointing from .util import (not_implemented_mode, subarray_transform, velocity_correction, - transform_bbox_from_shape, bounding_box_from_subarray) + transform_bbox_from_shape, bounding_box_from_subarray, + wl_order_identity) from ..lib.reffile_utils import find_row @@ -150,6 +151,10 @@ def imaging_distortion(input_model, reference_files): transform.bounding_box = transform_bbox_from_shape(input_model.data.shape) else: transform.bounding_box = bbox + + transform.name = "imaging_distortion" + transform.inputs = ('x', 'y') + transform.outputs = ('v2', 'v3') return transform @@ -259,40 +264,46 @@ def tsgrism(input_model, reference_files): setdec = Const1D(input_model.meta.wcsinfo.dec_ref) setdec.inverse = Const1D(input_model.meta.wcsinfo.dec_ref) + #wl_order_identity() = Identity(2) + #wl_order_identity().inputs = ('wavelength', 'order') + #wl_order_identity().outputs = ('wavelength', 'order') + # x, y, order in goes to transform to full array location and order # get the shift to full frame coordinates sub_trans = subarray_transform(input_model) if sub_trans is not None: sub2direct = (sub_trans & Identity(1) | Mapping((0, 1, 0, 1, 2)) | - (Identity(2) & xcenter & ycenter & Identity(1)) | + (wl_order_identity() & xcenter & ycenter & Identity(1)) | det2det) else: sub2direct = (Mapping((0, 1, 0, 1, 2)) | - (Identity(2) & xcenter & ycenter & Identity(1)) | + (wl_order_identity() & xcenter & ycenter & Identity(1)) | det2det) + sub2direct.name = "grism2image" + sub2direct.inputs = ('x', 'y', 'order') # take us from full frame detector to v2v3 - distortion = imaging_distortion(input_model, reference_files) & Identity(2) + distortion = imaging_distortion(input_model, reference_files) & wl_order_identity() # Compute differential velocity aberration (DVA) correction: va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, v3_ref=input_model.meta.wcsinfo.v3_ref - ) & Identity(2) + ) & wl_order_identity() # v2v3 to the sky # remap the tel2sky inverse as well since we can feed it the values of # crval1, crval2 which correspond to crpix1, crpix2. This leaves # us with a calling structure: # (x, y, order) <-> (wavelength, order) - tel2sky = pointing.v23tosky(input_model) & Identity(2) + tel2sky = pointing.v23tosky(input_model) & wl_order_identity() t2skyinverse = tel2sky.inverse - newinverse = Mapping((0, 1, 0, 1)) | setra & setdec & Identity(2) | t2skyinverse + newinverse = Mapping((0, 1, 0, 1)) | setra & setdec & wl_order_identity() | t2skyinverse tel2sky.inverse = newinverse pipeline = [(frames['grism_detector'], sub2direct), - (frames['direct_image'], distortion), + (frames['detector'], distortion), (frames['v2v3'], va_corr), (frames['v2v3vacorr'], tel2sky), (frames['world'], None)] @@ -431,7 +442,7 @@ def wfss(input_model, reference_files): world = image_pipeline.pop()[0] world.name = 'sky' for cframe, trans in image_pipeline: - trans = trans & (Identity(2)) + trans = trans & (wl_order_identity()) name = cframe.name cframe.name = name + 'spatial' spatial_and_spectral = cf.CompositeFrame([cframe, spec], @@ -456,7 +467,7 @@ def create_coord_frames(): spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), axes_names=('wavelength',)) frames = {'grism_detector': gdetector, - 'direct_image': cf.CompositeFrame([detector, spec], name='direct_image'), + 'detector': cf.CompositeFrame([detector, spec], name='detector'), 'v2v3': cf.CompositeFrame([v2v3_spatial, spec], name='v2v3'), 'v2v3vacorr': cf.CompositeFrame([v2v3vacorr_spatial, spec], name='v2v3vacorr'), 'world': cf.CompositeFrame([sky_frame, spec], name='world') diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py index 2da04a5751..f1ec23cc37 100644 --- a/jwst/assign_wcs/niriss.py +++ b/jwst/assign_wcs/niriss.py @@ -16,7 +16,7 @@ from .util import (not_implemented_mode, subarray_transform, velocity_correction, bounding_box_from_subarray, - transform_bbox_from_shape) + transform_bbox_from_shape, wl_order_identity) from . import pointing from ..lib.reffile_utils import find_row @@ -312,6 +312,10 @@ def imaging_distortion(input_model, reference_files): distortion.bounding_box = transform_bbox_from_shape(input_model.data.shape) else: distortion.bounding_box = bbox + + distortion.inputs = ('x', 'y') + distortion.outputs = ('v2', 'v3') + distortion.name = "imaging_distortion" return distortion @@ -471,7 +475,7 @@ def wfss(input_model, reference_files): world = image_pipeline.pop()[0] world.name = 'sky' for cframe, trans in image_pipeline: - trans = trans & (Identity(2)) + trans = trans & (wl_order_identity()) name = cframe.name cframe.name = name + 'spatial' spatial_and_spectral = cf.CompositeFrame([cframe, spec], name=name) diff --git a/jwst/assign_wcs/nirspec.py b/jwst/assign_wcs/nirspec.py index c6e3e8c2ef..3d6b3b0bb0 100644 --- a/jwst/assign_wcs/nirspec.py +++ b/jwst/assign_wcs/nirspec.py @@ -27,7 +27,8 @@ MSAFileError, NoDataOnDetectorError, not_implemented_mode, - velocity_correction + velocity_correction, + wl_identity ) from . import pointing from ..lib.exposure_types import is_nrs_ifu_lamp @@ -87,6 +88,8 @@ def imaging(input_model, reference_files): det2gwa = detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) gwa_through = Const1D(-1) * Identity(1) & Const1D(-1) * Identity(1) & Identity(1) + gwa_through.inputs = ('x', 'y', 'z') + gwa_through.outputs = ('x', 'y', 'z') angles = [disperser['theta_x'], disperser['theta_y'], disperser['theta_z'], disperser['tilt_y']] @@ -96,6 +99,8 @@ def imaging(input_model, reference_files): col_model = CollimatorModel(reference_files['collimator']) col = col_model.model col_model.close() + col.inputs = ('x', 'y') + col.outputs = ('x', 'y') # Get the default spectral order and wavelength range and record them in the model. sporder, wrange = get_spectral_order_wrange(input_model, reference_files['wavelengthrange']) @@ -106,8 +111,11 @@ def imaging(input_model, reference_files): lam = wrange[0] + (wrange[1] - wrange[0]) * .5 lam_model = Mapping((0, 1, 1)) | Identity(2) & Const1D(lam) + lam_model.inputs = ('x', 'y') + lam_model.outputs = ('x', 'y', 'lam') gwa2msa = gwa_through | rotation | dircos2unitless | col | lam_model + gwa2msa.name = "gwa_to_msa" gwa2msa.inverse = col.inverse | dircos2unitless.inverse | rotation.inverse | gwa_through # Create coordinate frames in the NIRSPEC WCS pipeline @@ -117,6 +125,9 @@ def imaging(input_model, reference_files): # MSA to OTEIP transform msa2ote = msa_to_oteip(reference_files) msa2oteip = msa2ote | Mapping((0, 1), n_inputs=3) + msa2oteip.name = "msa_to_oteip" + msa2oteip.inputs = ('x', 'y', 'lam') + msa2oteip.outputs = ('xan', 'yan') map1 = Mapping((0, 1, 0, 1)) minv = msa2ote.inverse del minv.inverse @@ -125,6 +136,9 @@ def imaging(input_model, reference_files): # OTEIP to V2,V3 transform with OTEModel(reference_files['ote']) as f: oteip2v23 = f.model + oteip2v23.name = "oteip_to_v2v3" + oteip2v23.inputs = ('xan', 'yan') + oteip2v23.outputs = ('v2', 'v3') # Compute differential velocity aberration (DVA) correction: va_corr = pointing.dva_corr_model( @@ -213,6 +227,8 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): # DMS to SCA transform dms2detector = dms_to_sca(input_model) # DETECTOR to GWA transform + + # what are the two additional inputs? det2gwa = Identity(2) & detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) @@ -221,7 +237,7 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): gwa2slit = gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) # SLIT to MSA transform - slit2slicer = ifuslit_to_slicer(slits, reference_files, input_model) + slit2slicer = ifuslit_to_slicer(slits, reference_files) # SLICER to MSA Entrance slicer2msa = slicer_to_msa(reference_files) @@ -235,8 +251,8 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): if input_model.meta.instrument.filter == 'OPAQUE' or is_lamp_exposure: # If filter is "OPAQUE" or if internal lamp exposure the NIRSPEC WCS pipeline stops at the MSA. pipeline = [(det, dms2detector), - (sca, det2gwa.rename('detector2gwa')), - (gwa, gwa2slit.rename('gwa2slit')), + (sca, det2gwa), + (gwa, gwa2slit), (slit_frame, slit2slicer), ('slicer', slicer2msa), (msa_frame, None)] @@ -252,10 +268,10 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, v3_ref=input_model.meta.wcsinfo.v3_ref - ) & Identity(1) + ) & wl_identity() # V2, V3 to sky - tel2sky = pointing.v23tosky(input_model) & Identity(1) + tel2sky = pointing.v23tosky(input_model) & wl_identity() # Create coordinate frames in the NIRSPEC WCS pipeline" # @@ -265,12 +281,12 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): # "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "world" pipeline = [(det, dms2detector), - (sca, det2gwa.rename('detector2gwa')), - (gwa, gwa2slit.rename('gwa2slit')), + (sca, det2gwa), + (gwa, gwa2slit), (slit_frame, slit2slicer), ('slicer', slicer2msa), - (msa_frame, msa2oteip.rename('msa2oteip')), - (oteip, oteip2v23.rename('oteip2v23')), + (msa_frame, msa2oteip), + (oteip, oteip2v23), (v2v3, va_corr), (v2v3vacorr, tel2sky), (world, None)] @@ -332,20 +348,17 @@ def slitlets_wcs(input_model, reference_files, open_slits_id): # DMS to SCA transform dms2detector = dms_to_sca(input_model) - dms2detector.name = 'dms2sca' + # DETECTOR to GWA transform det2gwa = Identity(2) & detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) - det2gwa.name = "det2gwa" # GWA to SLIT gwa2slit = gwa_to_slit(open_slits_id, input_model, disperser, reference_files) - gwa2slit.name = "gwa2slit" # SLIT to MSA transform slit2msa = slit_to_msa(open_slits_id, reference_files['msa']) - slit2msa.name = "slit2msa" # Create coordinate frames in the NIRSPEC WCS pipeline" # "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "v2v3vacorr", "world" @@ -370,14 +383,13 @@ def slitlets_wcs(input_model, reference_files, open_slits_id): # OTEIP to V2,V3 transform # This includes a wavelength unit conversion from meters to microns. oteip2v23 = oteip_to_v23(reference_files, input_model) - oteip2v23.name = "oteip2v23" # Compute differential velocity aberration (DVA) correction: va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, v3_ref=input_model.meta.wcsinfo.v3_ref - ) & Identity(1) + ) & wl_identity() # V2, V3 to sky tel2sky = pointing.v23tosky(input_model) & Identity(1) @@ -813,7 +825,7 @@ def get_spectral_order_wrange(input_model, wavelengthrange_file): return order, wrange -def ifuslit_to_slicer(slits, reference_files, input_model): +def ifuslit_to_slicer(slits, reference_files): """ The transform from ``slit_frame`` to ``slicer`` frame. @@ -823,7 +835,6 @@ def ifuslit_to_slicer(slits, reference_files, input_model): A list of slit IDs for all slices. reference_files : dict {reference_type: reference_file_name} - input_model : `~jwst.datamodels.IFUImageModel` Returns ------- @@ -836,13 +847,18 @@ def ifuslit_to_slicer(slits, reference_files, input_model): for slit in slits: slitdata = ifuslicer.data[slit] slitdata_model = (get_slit_location_model(slitdata)).rename('slitdata_model') - slicer_model = slitdata_model | ifuslicer_model - - msa_transform = slicer_model + msa_transform = slitdata_model | ifuslicer_model + msa_transform.name = "ifuslit_to_slicer" + msa_transform.inputs = ('x_slit', 'y_slit') + msa_transform.outputs = ('x_slit', 'y_slit') models.append(msa_transform) ifuslicer.close() - return Slit2Msa(slits, models) + transform = Slit2Msa(slits, models) + transform.name = "ifuslit_to_slicer" + #transform.inputs = ('name', 'x_slit', 'y_slit') + #transform.outputs = ('x_msa', 'y_msa') + return transform def slicer_to_msa(reference_files): @@ -858,7 +874,10 @@ def slicer_to_msa(reference_files): slicer2fore_mapping.inverse = Identity(3) ifufore2fore_mapping = Identity(1) ifufore2fore_mapping.inverse = Mapping((0, 1, 2, 2)) - ifu_fore_transform = slicer2fore_mapping | ifufore & Identity(1) + ifu_fore_transform = slicer2fore_mapping | ifufore & wl_identity() + ifu_fore_transform.name = "slicer_to_msa" + ifu_fore_transform.inputs = ('x', 'y', 'lam') + ifu_fore_transform.outputs = ('x_msa', 'y_msa', 'lam') return ifu_fore_transform @@ -896,10 +915,14 @@ def slit_to_msa(open_slits, msafile): slitdata = msa_data[slit_id] slitdata_model = get_slit_location_model(slitdata) msa_transform = slitdata_model | msa_model + msa_transform.name = "slit_to_msa" + msa_transform.inputs = ('x', 'y', 'lam') + msa_transform.outputs = ('x', 'y', 'lam') models.append(msa_transform) slits.append(slit) msa.close() - return Slit2Msa(slits, models) + transform = Slit2Msa(slits, models) + return transform def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range): @@ -989,11 +1012,16 @@ def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) msa2gwa_out = ifuslicer_transform & Identity(1) | ifupost_transform | collimator2gwa msa2bgwa = Mapping((0, 1, 2, 2)) | msa2gwa_out & Identity(1) | Mapping((3, 0, 1, 2)) | agreq bgwa2msa.inverse = msa2bgwa + bgwa2msa.name = "gwa_to_ifuslit" + bgwa2msa.inputs = ('alpha', 'beta', 'gamma') + bgwa2msa.outputs = ('x_slit', 'y_slit', 'lam') slit_models.append(bgwa2msa) ifuslicer.close() ifupost.close() - return Gwa2Slit(slits, slit_models) + transform = Gwa2Slit(slits, slit_models) + transform.name = "gwa_to_ifuslit" + return transform def gwa_to_slit(open_slits, input_model, disperser, @@ -1078,7 +1106,9 @@ def gwa_to_slit(open_slits, input_model, disperser, slit_models.append(bgwa2msa) slits.append(slit) msa.close() - return Gwa2Slit(slits, slit_models) + transform = Gwa2Slit(slits, slit_models) + transform.name = "gwa_to_slit" + return transform def angle_from_disperser(disperser, input_model): @@ -1190,6 +1220,9 @@ def detector_to_gwa(reference_files, detector, disperser): models.Shift(-1) & models.Shift(-1) | fpa | camera | u2dircos | rotation ''' model = fpa | camera | u2dircos | rotation + model.name = 'sca_to_gwa' + # input names already handled in stdatamodels Rotation3DToGWA + model.outputs = ('alpha', 'beta', 'gamma') return model @@ -1214,7 +1247,11 @@ def dms_to_sca(input_model): model = models.Shift(-2047) & models.Shift(-2047) | models.Scale(-1) & models.Scale(-1) elif detector == 'NRS1': model = models.Identity(2) - return subarray2full | model + dms2sca = subarray2full | model + dms2sca.inputs = ('x', 'y') + dms2sca.outputs = ('x', 'y') + dms2sca.name = 'dms_to_sca' + return dms2sca def mask_slit(ymin=-.55, ymax=.55): @@ -1437,7 +1474,10 @@ def ifu_msa_to_oteip(reference_files): msa2fore_mapping = Mapping((0, 1, 2, 2), name='msa2fore_mapping') msa2fore_mapping.inverse = Mapping((0, 1, 2, 2), name='fore2msa') - fore_transform = msa2fore_mapping | fore & Identity(1) + fore_transform = msa2fore_mapping | fore & wl_identity() + fore_transform.name = "msa_to_oteip" + fore_transform.inputs = ('x_msa', 'y_msa', 'lam') + fore_transform.outputs = ('xan', 'yan', 'lam') return fore_transform @@ -1460,7 +1500,11 @@ def msa_to_oteip(reference_files): fore = f.model msa2fore_mapping = Mapping((0, 1, 2, 2), name='msa2fore_mapping') msa2fore_mapping.inverse = Identity(3) - return msa2fore_mapping | (fore & Identity(1)) + transform = msa2fore_mapping | (fore & wl_identity()) + transform.name = "msa_to_oteip" + transform.inputs = ('x', 'y', 'lam') + transform.outputs = ('xan', 'yan', 'lam') + return transform def oteip_to_v23(reference_files, input_model): @@ -1489,6 +1533,9 @@ def oteip_to_v23(reference_files, input_model): # The spatial units are currently in deg. Convertin to arcsec. oteip2v23 = fore2ote_mapping | (ote & Scale(1e6)) + oteip2v23.name = "oteip_to_v2v3" + oteip2v23.inputs = ('xan', 'yan', 'lam') + oteip2v23.outputs = ('v2', 'v3', 'lam') return oteip2v23 diff --git a/jwst/assign_wcs/pointing.py b/jwst/assign_wcs/pointing.py index 88c7f7b2b4..f422cb19fc 100644 --- a/jwst/assign_wcs/pointing.py +++ b/jwst/assign_wcs/pointing.py @@ -27,6 +27,8 @@ def _v23tosky(v2_ref, v3_ref, roll_ref, ra_ref, dec_ref, wrap_v2_at=180, wrap_lo m = ((Scale(1 / 3600) & Scale(1 / 3600)) | SphericalToCartesian(wrap_lon_at=wrap_v2_at) | rot | CartesianToSpherical(wrap_lon_at=wrap_lon_at)) m.name = 'v23tosky' + m.inputs = ('v2', 'v3') + m.outputs = ('ra', 'dec') return m @@ -274,13 +276,21 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): then `astropy.modeling.models.Identity` will be returned. """ + if va_scale is None or va_scale == 1: - return Identity(2) + va_corr = Identity(2) + va_corr.name = 'DVA_Correction' + va_corr.inputs = ('v2', 'v3') + va_corr.outputs = ('v2', 'v3') + return va_corr if va_scale <= 0: raise ValueError("'Velocity aberration scale must be a positive number.") va_corr = Scale(va_scale, name='dva_scale_v2') & Scale(va_scale, name='dva_scale_v3') + va_corr.name = 'DVA_Correction' + va_corr.inputs = ('v2', 'v3') + va_corr.outputs = ('v2', 'v3') if v2_ref is None: v2_ref = 0 @@ -289,6 +299,7 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): v3_ref = 0 if v2_ref == 0 and v3_ref == 0: + return va_corr # NOTE: it is assumed that v2, v3 angles and va scale are small enough @@ -298,5 +309,4 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): v3_shift = (1 - va_scale) * v3_ref va_corr |= Shift(v2_shift, name='dva_v2_shift') & Shift(v3_shift, name='dva_v3_shift') - va_corr.name = 'DVA_Correction' return va_corr diff --git a/jwst/assign_wcs/tests/test_miri.py b/jwst/assign_wcs/tests/test_miri.py index a698ec30af..f548a94ae4 100644 --- a/jwst/assign_wcs/tests/test_miri.py +++ b/jwst/assign_wcs/tests/test_miri.py @@ -11,6 +11,7 @@ from gwcs import wcs from numpy.testing import assert_allclose +import pytest from stdatamodels.jwst.datamodels import ImageModel, CubeModel from jwst.assign_wcs import miri @@ -51,6 +52,27 @@ def create_hdul(detector, channel, band): return hdul +@pytest.fixture +def create_hdul_lrs(): + hdul = fits.HDUList() + phdu = fits.PrimaryHDU() + phdu.header['telescop'] = "JWST" + phdu.header['filename'] = "test" + phdu.header['instrume'] = 'MIRI' + phdu.header['detector'] = 'MIRIMAGE' + phdu.header['CHANNEL'] = '34' + phdu.header['BAND'] = 'SHORT' + phdu.header['time-obs'] = '8:59:37' + phdu.header['date-obs'] = '2017-09-05' + phdu.header['exp_type'] = 'MIR_LRS-FIXEDSLIT' + scihdu = fits.ImageHDU() + scihdu.header['EXTNAME'] = "SCI" + scihdu.header.update(wcs_kw) + hdul.append(phdu) + hdul.append(scihdu) + return hdul + + def create_datamodel(hdul): im = ImageModel(hdul) ref = create_reference_files(im) @@ -164,6 +186,31 @@ def test_mrs_tso_bounding_box(): assert_allclose(cube.meta.wcs.bounding_box, ((-.5, 49.5), (-.5, 39.5))) +def test_transform_metadata_mrs(): + hdul = create_hdul(detector="MIRIFULONG", channel="34", band="MEDIUM") + wcs = create_datamodel(hdul).meta.wcs + + assert wcs.get_transform("detector", "alpha_beta").inputs == ('x', 'y') + assert wcs.get_transform("detector", "alpha_beta").outputs == ('alpha', 'beta', 'lam') + assert wcs.get_transform("alpha_beta", "v2v3").inputs == ('alpha', 'beta', 'lam') + assert wcs.get_transform("alpha_beta", "v2v3").outputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3vacorr", "world").inputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3vacorr", "world").outputs == ('ra', 'dec', 'lam') + + +def test_transform_metadata_lrs(create_hdul_lrs): + wcs = create_datamodel(create_hdul_lrs).meta.wcs + + assert wcs.get_transform("detector", "v2v3").inputs == ('x', 'y') + assert wcs.get_transform("detector", "v2v3").outputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3vacorr", "world").inputs == ('v2', 'v3', 'lam') + assert wcs.get_transform("v2v3vacorr", "world").outputs == ('ra', 'dec', 'lam') + + # MRS test reference data mrs_ref_data = { '1A': {'x': np.array([76.0, 354.0]), diff --git a/jwst/assign_wcs/tests/test_nircam.py b/jwst/assign_wcs/tests/test_nircam.py index 3ee141f0f9..b9c7349cdd 100644 --- a/jwst/assign_wcs/tests/test_nircam.py +++ b/jwst/assign_wcs/tests/test_nircam.py @@ -27,7 +27,7 @@ nircam_wfss_frames = ['grism_detector', 'detector', 'v2v3', 'v2v3vacorr', 'world'] -nircam_tsgrism_frames = ['grism_detector', 'direct_image', 'v2v3', 'v2v3vacorr', 'world'] +nircam_tsgrism_frames = ['grism_detector', 'detector', 'v2v3', 'v2v3vacorr', 'world'] nircam_imaging_frames = ['detector', 'v2v3', 'v2v3vacorr', 'world'] @@ -90,6 +90,7 @@ def create_wfss_wcs(pupil, filtername='F444W'): return wcsobj +@pytest.fixture def create_imaging_wcs(): hdul = create_hdul() image = ImageModel(hdul) @@ -202,12 +203,12 @@ def test_traverse_wfss_grisms(): traverse_wfss_trace(pupil) -def test_traverse_tso_grism(create_tso_wcs): - """Make sure that the TSO dispersion polynomials are reversable. - All assert statements are in pixel space so 1/1000 px seems easily acceptable""" - wcsobj = create_tso_wcs - detector_to_grism = wcsobj.get_transform('direct_image', 'grism_detector') - grism_to_detector = wcsobj.get_transform('grism_detector', 'direct_image') +@pytest.mark.xfail(reason="Fails due to V2 NIRCam specwcs ref files delivered to CRDS") +def test_traverse_tso_grism(): + """Make sure that the TSO dispersion polynomials are reversable.""" + wcsobj = create_tso_wcs() + detector_to_grism = wcsobj.get_transform('detector', 'grism_detector') + grism_to_detector = wcsobj.get_transform('grism_detector', 'detector') # TSGRISM always has same source locations # takes x,y,order -> ra, dec, wave, order @@ -227,9 +228,9 @@ def test_traverse_tso_grism(create_tso_wcs): # assert np.isclose(y, wcs_tso_kw['yref_sci']) -def test_imaging_frames(): +def test_imaging_frames(create_imaging_wcs): """Verify the available imaging mode reference frames.""" - wcsobj = create_imaging_wcs() + wcsobj = create_imaging_wcs available_frames = wcsobj.available_frames assert all([a == b for a, b in zip(nircam_imaging_frames, available_frames)]) @@ -244,3 +245,31 @@ def test_wfss_sip(): util.wfss_imaging_wcs(wfss_model, nircam.imaging, bbox=((1, 1024), (1, 1024))) for key in ['a_order', 'b_order', 'crpix1', 'crpix2', 'crval1', 'crval2', 'cd1_1']: assert key in wfss_model.meta.wcsinfo.instance + + +def test_transform_metadata_imaging(create_imaging_wcs): + wcsobj = create_imaging_wcs + assert wcsobj.get_transform("detector", "v2v3").inputs == ('x', 'y') + assert wcsobj.get_transform("detector", "v2v3").outputs == ('v2', 'v3') + assert wcsobj.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3') + assert wcsobj.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3') + assert wcsobj.get_transform("v2v3vacorr", "world").inputs == ('v2', 'v3') + assert wcsobj.get_transform("v2v3vacorr", "world").outputs == ('ra', 'dec') + + +@pytest.mark.parametrize('exptype', ['tso', 'wfss']) +def test_transform_metadata_grism(exptype): + if exptype == "tso": + wcsobj = create_tso_wcs() + assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x', 'y', 'order') + elif exptype == "wfss": + wcsobj = create_wfss_wcs('GRISMR') + assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x', 'y', 'x0', 'y0', 'order') + + assert wcsobj.get_transform("grism_detector", "detector").outputs == ('x_direct', 'y_direct', 'wavelength', 'order') + assert wcsobj.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct', 'wavelength', 'order') + assert wcsobj.get_transform("detector", "v2v3").outputs == ('v2', 'v3', 'wavelength', 'order') + assert wcsobj.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3', 'wavelength', 'order') + assert wcsobj.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3', 'wavelength', 'order') + assert wcsobj.get_transform("v2v3vacorr", "world").inputs == ('v2', 'v3', 'wavelength', 'order') + assert wcsobj.get_transform("v2v3vacorr", "world").outputs == ('ra', 'dec', 'wavelength', 'order') diff --git a/jwst/assign_wcs/tests/test_niriss.py b/jwst/assign_wcs/tests/test_niriss.py index 8f6fa7ab69..752af25f25 100644 --- a/jwst/assign_wcs/tests/test_niriss.py +++ b/jwst/assign_wcs/tests/test_niriss.py @@ -184,3 +184,28 @@ def test_wfss_sip(): util.wfss_imaging_wcs(wfss_model, niriss.imaging, max_pix_error=0.05, bbox=((1, 1024), (1, 1024))) for key in ['a_order', 'b_order', 'crpix1', 'crpix2', 'crval1', 'crval2', 'cd1_1']: assert key in wfss_model.meta.wcsinfo.instance + + + +def test_transform_metadata_imaging(): + + wcs = create_imaging_wcs('F200W') + assert wcs.get_transform("detector", "v2v3").inputs == ('x', 'y') + assert wcs.get_transform("detector", "v2v3").outputs == ('v2', 'v3') + assert wcs.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3') + assert wcs.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3') + assert wcs.get_transform("v2v3vacorr", "world").inputs == ('v2', 'v3') + assert wcs.get_transform("v2v3vacorr", "world").outputs == ('ra', 'dec') + + +def test_transform_metadata_wfss(): + + wcs = create_wfss_wcs('GR150R') + assert wcs.get_transform("grism_detector", "detector").inputs == ('x', 'y', 'x0', 'y0', 'order') + assert wcs.get_transform("grism_detector", "detector").outputs == ('x_direct', 'y_direct', 'wavelength', 'order') + assert wcs.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct', 'wavelength', 'order') + assert wcs.get_transform("detector", "v2v3").outputs == ('v2', 'v3', 'wavelength', 'order') + assert wcs.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3', 'wavelength', 'order') + assert wcs.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3', 'wavelength', 'order') + assert wcs.get_transform("v2v3vacorr", "world").inputs == ('v2', 'v3', 'wavelength', 'order') + assert wcs.get_transform("v2v3vacorr", "world").outputs == ('ra', 'dec', 'wavelength', 'order') diff --git a/jwst/assign_wcs/tests/test_nirspec.py b/jwst/assign_wcs/tests/test_nirspec.py index a517916835..67972fc24d 100644 --- a/jwst/assign_wcs/tests/test_nirspec.py +++ b/jwst/assign_wcs/tests/test_nirspec.py @@ -162,6 +162,25 @@ def test_nirspec_imaging(): # Test evaluating the WCS im.meta.wcs(1, 2) + # ensure transform metadata looks correct + # available frames ['detector', 'sca', 'gwa', 'msa', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] + # see technical report JWST-STScI-005921, SM-12 for details + assert w.get_transform('detector', 'sca').inputs == ('x', 'y') + assert w.get_transform('detector', 'sca').outputs == ('x', 'y') + assert w.get_transform('sca', 'gwa').inputs == ('x', 'y') + assert w.get_transform('sca', 'gwa').outputs == ('x', 'y', 'z') + assert w.get_transform('gwa', 'msa').inputs == ('x', 'y', 'z') + assert w.get_transform('gwa', 'msa').outputs == ('x', 'y', 'lam') + assert w.get_transform('msa', 'oteip').inputs == ('x', 'y', 'lam') + assert w.get_transform('msa', 'oteip').outputs == ('xan', 'yan') + assert w.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan') + assert w.get_transform('oteip', 'v2v3').outputs == ("v2", "v3") + assert w.get_transform('v2v3', 'v2v3vacorr').inputs == ("v2", "v3") + assert w.get_transform('v2v3', 'v2v3vacorr').outputs == ("v2", "v3") + assert w.get_transform('v2v3vacorr', 'world').inputs == ("v2", "v3") + assert w.get_transform('v2v3vacorr', 'world').outputs == ("ra", "dec") + + def test_nirspec_ifu_against_esa(wcs_ifu_grating): """ @@ -486,6 +505,26 @@ def test_functional_fs_msa(mode): im.meta.wcs = w slit_wcs = nirspec.nrs_wcs_set_input(im, 1) + # add metadata tests for msa, which cover fs too because both call slitlets_wcs + assert w.get_transform('detector', 'sca').inputs == ('x', 'y') + assert w.get_transform('detector', 'sca').outputs == ('x', 'y') + assert w.get_transform('sca', 'gwa').inputs == ("?", "?", 'x', 'y') # I think these are x,y of the slit center + assert w.get_transform('sca', 'gwa').outputs == ("?", "?", 'alpha', 'beta', 'gamma') + assert w.get_transform('gwa', 'slit_frame').inputs == ('name', 'alpha', 'beta', 'gamma') + assert w.get_transform('gwa', 'slit_frame').outputs == ('name', 'x_slit', 'y_slit', 'lam') + assert w.get_transform('slit_frame', 'slicer').inputs == ('x_slit', 'y_slit', 'lam') + assert w.get_transform('slit_frame', 'slicer').outputs == ('x_slicer', 'y_slicer', 'lam') + assert w.get_transform('slicer', 'msa_frame').inputs == ('x_slicer', 'y_slicer', 'lam') + assert w.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') + assert w.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') + assert w.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') + assert w.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan', 'lam') + assert w.get_transform('oteip', 'v2v3').outputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3', 'v2v3vacorr').inputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3', 'v2v3vacorr').outputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3vacorr', 'world').inputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3vacorr', 'world').outputs == ("ra", "dec", 'lam') + ins_file = get_file_path(model_file) ins_tab = table.Table.read(ins_file, format='ascii') @@ -748,6 +787,28 @@ def test_functional_ifu_grating(wcs_ifu_grating): assert_allclose(v2, ins_tab['xV2V3']) assert_allclose(v3, ins_tab['yV2V3']) + # test transform metadata for slit_wcs + # available frames ['detector', 'sca', 'gwa', 'slit_frame', 'slicer', 'msa_frame', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] + assert slit_wcs.get_transform('detector', 'sca').inputs == ('x', 'y') + assert slit_wcs.get_transform('detector', 'sca').outputs == ('x', 'y') + assert slit_wcs.get_transform('sca', 'gwa').inputs == ('x', 'y') + assert slit_wcs.get_transform('sca', 'gwa').outputs == ('alpha', 'beta', 'gamma') + assert slit_wcs.get_transform('gwa', 'slit_frame').inputs == ('alpha', 'beta', 'gamma') + assert slit_wcs.get_transform('gwa', 'slit_frame').outputs == ('x_slit', 'y_slit', 'lam') + assert slit_wcs.get_transform('slit_frame', 'slicer').inputs == ('x_slit', 'y_slit', 'lam') + assert slit_wcs.get_transform('slit_frame', 'slicer').outputs == ('x_slicer', 'y_slicer', 'lam') + assert slit_wcs.get_transform('slicer', 'msa_frame').inputs == ('x_slicer', 'y_slicer', 'lam') + assert slit_wcs.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') + assert slit_wcs.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') + assert slit_wcs.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') + assert slit_wcs.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan', 'lam') + assert slit_wcs.get_transform('oteip', 'v2v3').outputs == ("v2", "v3", 'lam') + assert slit_wcs.get_transform('v2v3', 'v2v3vacorr').inputs == ("v2", "v3", 'lam') + assert slit_wcs.get_transform('v2v3', 'v2v3vacorr').outputs == ("v2", "v3", 'lam') + assert slit_wcs.get_transform('v2v3vacorr', 'world').inputs == ("v2", "v3", 'lam') + assert slit_wcs.get_transform('v2v3vacorr', 'world').outputs == ("ra", "dec", 'lam') + + def test_functional_ifu_prism(): """Compare Nirspec instrument model with IDT model for IFU prism.""" @@ -863,6 +924,31 @@ def test_functional_ifu_prism(): assert_allclose(v2, ins_tab['xV2V3']) assert_allclose(v3, ins_tab['yV2V3']) + # test transform metadata for wcs + # available frames ['detector', 'sca', 'gwa', 'slit_frame', 'slicer', 'msa_frame', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] + # read-only - must extract single slit id to execute the wcs + breakpoint() + assert w.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') + assert w.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') + assert w.get_transform('sca', 'gwa').inputs == ('null', 'null', 'x_sca', 'y_sca') + assert w.get_transform('sca', 'gwa').outputs == ('null', 'null', 'alpha', 'beta', 'gamma') + assert w.get_transform('gwa', 'slit_frame').inputs == ('name', 'alpha', 'beta', 'gamma') + assert w.get_transform('gwa', 'slit_frame').outputs == ('name', 'x_slit', 'y_slit', 'lam') + assert w.get_transform('slit_frame', 'slicer').inputs == ('x_slit', 'y_slit', 'lam') + assert w.get_transform('slit_frame', 'slicer').outputs == ('x_slicer', 'y_slicer', 'lam') + assert w.get_transform('slicer', 'msa_frame').inputs == ('x_slicer', 'y_slicer', 'lam') + assert w.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') + assert w.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') + assert w.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') + assert w.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan', 'lam') + assert w.get_transform('oteip', 'v2v3').outputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3', 'v2v3vacorr').inputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3', 'v2v3vacorr').outputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3vacorr', 'world').inputs == ("v2", "v3", 'lam') + assert w.get_transform('v2v3vacorr', 'world').outputs == ("ra", "dec", 'lam') + + # can they be serialized? when this is done are the names restored? + def test_ifu_bbox(): bbox = {0: ((122.0908542999878, 1586.2584665188083), diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py index f9f90918b4..118644638e 100644 --- a/jwst/assign_wcs/util.py +++ b/jwst/assign_wcs/util.py @@ -1503,3 +1503,25 @@ def get_wcs_reference_files(datamodel): else: refs[reftype] = val return refs + + +def wl_identity(): + """ + Takes inputs (wavelength,), does nothing, returns (wavelength,). + """ + identity = astmodels.Identity(1) + identity.inputs = ("lam",) + identity.outputs = ("lam",) + identity.name = "WlIdentity" + return identity + + +def wl_order_identity(): + """ + Takes inputs (wavelength, order), does nothing, returns (wavelength, order). + """ + identity = astmodels.Identity(2) + identity.inputs = ('wavelength', 'order') + identity.outputs = ('wavelength', 'order') + identity.name = 'wl_order_identity' + return identity \ No newline at end of file From cedb2e2209a8982e22e3f3239844e2f710a53d35 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 6 May 2024 09:40:52 -0400 Subject: [PATCH 2/4] intermediate progress, fixes except nirspec and mtwcs --- jwst/assign_mtwcs/tests/test_mtwcs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jwst/assign_mtwcs/tests/test_mtwcs.py b/jwst/assign_mtwcs/tests/test_mtwcs.py index a33b848872..0f6f56b1e0 100644 --- a/jwst/assign_mtwcs/tests/test_mtwcs.py +++ b/jwst/assign_mtwcs/tests/test_mtwcs.py @@ -20,3 +20,8 @@ def test_mt_multislit(): assert result[0].slits[0].meta.wcs.output_frame.name == 'moving_target' assert len(result[1].slits) == 1 assert result[1].slits[0].meta.wcs.output_frame.name == 'moving_target' + + # test wcs transform metadata + wcs = result[0].slits[0].meta.wcs + print(wcs.available_frames) + breakpoint() \ No newline at end of file From 09be0a7fafda512d8892e680f687bed3daf5e04b Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 7 May 2024 09:03:04 -0400 Subject: [PATCH 3/4] intermediate commit, still working on nirspec --- jwst/assign_mtwcs/moving_target_wcs.py | 5 ++- jwst/assign_mtwcs/tests/test_mtwcs.py | 5 --- jwst/assign_wcs/miri.py | 4 +- jwst/assign_wcs/nircam.py | 10 +++-- jwst/assign_wcs/niriss.py | 5 ++- jwst/assign_wcs/nirspec.py | 57 +++++++++++++++----------- jwst/assign_wcs/tests/test_miri.py | 4 +- jwst/assign_wcs/tests/test_nircam.py | 6 +-- jwst/assign_wcs/tests/test_niriss.py | 4 +- jwst/assign_wcs/tests/test_nirspec.py | 47 ++++++++++----------- 10 files changed, 77 insertions(+), 70 deletions(-) diff --git a/jwst/assign_mtwcs/moving_target_wcs.py b/jwst/assign_mtwcs/moving_target_wcs.py index 7f61b84e9c..4cf64b5d4c 100644 --- a/jwst/assign_mtwcs/moving_target_wcs.py +++ b/jwst/assign_mtwcs/moving_target_wcs.py @@ -10,9 +10,10 @@ import logging from copy import deepcopy import numpy as np -from astropy.modeling.models import Shift, Identity +from astropy.modeling.models import Shift from gwcs import WCS from gwcs import coordinate_frames as cf +from jwst.assign_wcs.util import wl_identity from stdatamodels.jwst import datamodels @@ -99,7 +100,7 @@ def add_mt_frame(wcs, ra_average, dec_average, mt_ra, mt_dec): if isinstance(mt, cf.CelestialFrame): transform_to_mt = Shift(rdel) & Shift(ddel) elif isinstance(mt, cf.CompositeFrame): - transform_to_mt = Shift(rdel) & Shift(ddel) & Identity(1) + transform_to_mt = Shift(rdel) & Shift(ddel) & wl_identity() else: raise ValueError("Unrecognized coordinate frame.") diff --git a/jwst/assign_mtwcs/tests/test_mtwcs.py b/jwst/assign_mtwcs/tests/test_mtwcs.py index 0f6f56b1e0..a33b848872 100644 --- a/jwst/assign_mtwcs/tests/test_mtwcs.py +++ b/jwst/assign_mtwcs/tests/test_mtwcs.py @@ -20,8 +20,3 @@ def test_mt_multislit(): assert result[0].slits[0].meta.wcs.output_frame.name == 'moving_target' assert len(result[1].slits) == 1 assert result[1].slits[0].meta.wcs.output_frame.name == 'moving_target' - - # test wcs transform metadata - wcs = result[0].slits[0].meta.wcs - print(wcs.available_frames) - breakpoint() \ No newline at end of file diff --git a/jwst/assign_wcs/miri.py b/jwst/assign_wcs/miri.py index 28dc3a5d4e..cfeed886d1 100644 --- a/jwst/assign_wcs/miri.py +++ b/jwst/assign_wcs/miri.py @@ -363,7 +363,7 @@ def lrs_distortion(input_model, reference_files): dettotel.bounding_box = bb_sub[::-1] dettotel.name = "lrs_distortion" - dettotel.inputs = ("x", "y") + dettotel.inputs = ("x_direct", "y_direct") dettotel.outputs = ("v2", "v3", "lam") return dettotel @@ -482,7 +482,7 @@ def detector_to_abl(input_model, reference_files): with WavelengthrangeModel(reference_files['wavelengthrange']) as f: wr = dict(zip(f.waverange_selector, f.wavelengthrange)) - det_labels = ('x', 'y') + det_labels = ('x_direct', 'y_direct') abl_labels = ('alpha', 'beta', 'lam') ch_dict = {} for c in channel: diff --git a/jwst/assign_wcs/nircam.py b/jwst/assign_wcs/nircam.py index 26dc770a35..5afd624816 100644 --- a/jwst/assign_wcs/nircam.py +++ b/jwst/assign_wcs/nircam.py @@ -153,7 +153,7 @@ def imaging_distortion(input_model, reference_files): transform.bounding_box = bbox transform.name = "imaging_distortion" - transform.inputs = ('x', 'y') + transform.inputs = ('x_direct', 'y_direct') transform.outputs = ('v2', 'v3') return transform @@ -279,8 +279,9 @@ def tsgrism(input_model, reference_files): sub2direct = (Mapping((0, 1, 0, 1, 2)) | (wl_order_identity() & xcenter & ycenter & Identity(1)) | det2det) - sub2direct.name = "grism2image" - sub2direct.inputs = ('x', 'y', 'order') + sub2direct.name = "grism_to_image" + sub2direct.inputs = ('x_grism', 'y_grism', 'order') + sub2direct.outputs = ('x_direct', 'y_direct', 'wavelength', 'order') # take us from full frame detector to v2v3 distortion = imaging_distortion(input_model, reference_files) & wl_order_identity() @@ -421,6 +422,9 @@ def wfss(input_model, reference_files): velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value)) det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1) + det2det.name = "grism_dispersion" + det2det.inputs = ('x_grism', 'y_grism', 'x0', 'y0', 'order') + det2det.outputs = ('x_direct', 'y_direct', 'wavelength', 'order') # create the pipeline to construct a WCS object for the whole image # which can translate ra,dec to image frame reference pixels diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py index f1ec23cc37..50777db266 100644 --- a/jwst/assign_wcs/niriss.py +++ b/jwst/assign_wcs/niriss.py @@ -313,7 +313,7 @@ def imaging_distortion(input_model, reference_files): else: distortion.bounding_box = bbox - distortion.inputs = ('x', 'y') + distortion.inputs = ('x_direct', 'y_direct') distortion.outputs = ('v2', 'v3') distortion.name = "imaging_distortion" return distortion @@ -450,6 +450,9 @@ def wfss(input_model, reference_files): velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value)) det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1) + det2det.name = "grism_dispersion" + det2det.inputs = ('x_grism', 'y_grism', 'x0', 'y0', 'order') + det2det.outputs = ('x_direct', 'y_direct', 'wavelength', 'order') # create the pipeline to construct a WCS object for the whole image # which can translate ra,dec to image frame reference pixels diff --git a/jwst/assign_wcs/nirspec.py b/jwst/assign_wcs/nirspec.py index 3d6b3b0bb0..ed2fc6bdf9 100644 --- a/jwst/assign_wcs/nirspec.py +++ b/jwst/assign_wcs/nirspec.py @@ -88,8 +88,6 @@ def imaging(input_model, reference_files): det2gwa = detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) gwa_through = Const1D(-1) * Identity(1) & Const1D(-1) * Identity(1) & Identity(1) - gwa_through.inputs = ('x', 'y', 'z') - gwa_through.outputs = ('x', 'y', 'z') angles = [disperser['theta_x'], disperser['theta_y'], disperser['theta_z'], disperser['tilt_y']] @@ -99,8 +97,6 @@ def imaging(input_model, reference_files): col_model = CollimatorModel(reference_files['collimator']) col = col_model.model col_model.close() - col.inputs = ('x', 'y') - col.outputs = ('x', 'y') # Get the default spectral order and wavelength range and record them in the model. sporder, wrange = get_spectral_order_wrange(input_model, reference_files['wavelengthrange']) @@ -111,11 +107,11 @@ def imaging(input_model, reference_files): lam = wrange[0] + (wrange[1] - wrange[0]) * .5 lam_model = Mapping((0, 1, 1)) | Identity(2) & Const1D(lam) - lam_model.inputs = ('x', 'y') - lam_model.outputs = ('x', 'y', 'lam') gwa2msa = gwa_through | rotation | dircos2unitless | col | lam_model gwa2msa.name = "gwa_to_msa" + gwa2msa.inputs = ('alpha', 'beta', 'gamma') + gwa2msa.outputs = ('x_msa', 'y_msa', 'lam') gwa2msa.inverse = col.inverse | dircos2unitless.inverse | rotation.inverse | gwa_through # Create coordinate frames in the NIRSPEC WCS pipeline @@ -126,7 +122,7 @@ def imaging(input_model, reference_files): msa2ote = msa_to_oteip(reference_files) msa2oteip = msa2ote | Mapping((0, 1), n_inputs=3) msa2oteip.name = "msa_to_oteip" - msa2oteip.inputs = ('x', 'y', 'lam') + msa2oteip.inputs = ('x_msa', 'y_msa', 'lam') msa2oteip.outputs = ('xan', 'yan') map1 = Mapping((0, 1, 0, 1)) minv = msa2ote.inverse @@ -169,6 +165,15 @@ def imaging(input_model, reference_files): return imaging_pipeline +def null_identity(n): + '''Make placeholder Identity transform to pass in two new empty dimesnsions''' + transform = Identity(n) + transform.name = "null_identity" + transform.inputs = ('null',)*n + transform.outputs = ('null',)*n + return transform + + def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): """ The Nirspec IFU WCS pipeline. @@ -228,8 +233,7 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): dms2detector = dms_to_sca(input_model) # DETECTOR to GWA transform - # what are the two additional inputs? - det2gwa = Identity(2) & detector_to_gwa(reference_files, + det2gwa = null_identity(2) & detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) @@ -279,7 +283,7 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): # in the whole pipeline) to microns (which is the expected output) # # "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "world" - + breakpoint() pipeline = [(det, dms2detector), (sca, det2gwa), (gwa, gwa2slit), @@ -350,7 +354,7 @@ def slitlets_wcs(input_model, reference_files, open_slits_id): dms2detector = dms_to_sca(input_model) # DETECTOR to GWA transform - det2gwa = Identity(2) & detector_to_gwa(reference_files, + det2gwa = null_identity(2) & detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) @@ -392,7 +396,7 @@ def slitlets_wcs(input_model, reference_files, open_slits_id): ) & wl_identity() # V2, V3 to sky - tel2sky = pointing.v23tosky(input_model) & Identity(1) + tel2sky = pointing.v23tosky(input_model) & wl_identity() tel2sky.name = "v2v3_to_sky" msa_pipeline = [(det, dms2detector), @@ -850,14 +854,12 @@ def ifuslit_to_slicer(slits, reference_files): msa_transform = slitdata_model | ifuslicer_model msa_transform.name = "ifuslit_to_slicer" msa_transform.inputs = ('x_slit', 'y_slit') - msa_transform.outputs = ('x_slit', 'y_slit') + msa_transform.outputs = ('x_slicer', 'y_slicer') models.append(msa_transform) ifuslicer.close() transform = Slit2Msa(slits, models) transform.name = "ifuslit_to_slicer" - #transform.inputs = ('name', 'x_slit', 'y_slit') - #transform.outputs = ('x_msa', 'y_msa') return transform @@ -874,9 +876,9 @@ def slicer_to_msa(reference_files): slicer2fore_mapping.inverse = Identity(3) ifufore2fore_mapping = Identity(1) ifufore2fore_mapping.inverse = Mapping((0, 1, 2, 2)) - ifu_fore_transform = slicer2fore_mapping | ifufore & wl_identity() + ifu_fore_transform = slicer2fore_mapping | ifufore & Identity(1) ifu_fore_transform.name = "slicer_to_msa" - ifu_fore_transform.inputs = ('x', 'y', 'lam') + ifu_fore_transform.inputs = ('x_slicer', 'y_slicer', 'lam') ifu_fore_transform.outputs = ('x_msa', 'y_msa', 'lam') return ifu_fore_transform @@ -916,8 +918,8 @@ def slit_to_msa(open_slits, msafile): slitdata_model = get_slit_location_model(slitdata) msa_transform = slitdata_model | msa_model msa_transform.name = "slit_to_msa" - msa_transform.inputs = ('x', 'y', 'lam') - msa_transform.outputs = ('x', 'y', 'lam') + msa_transform.inputs = ('x_slit', 'y_slit') + msa_transform.outputs = ('x_msa', 'y_msa') models.append(msa_transform) slits.append(slit) msa.close() @@ -1009,7 +1011,7 @@ def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) ) # transform from ``msa_frame`` to ``gwa`` frame (before the GWA going from detector to sky). - msa2gwa_out = ifuslicer_transform & Identity(1) | ifupost_transform | collimator2gwa + msa2gwa_out = ifuslicer_transform & wl_identity() | ifupost_transform | collimator2gwa msa2bgwa = Mapping((0, 1, 2, 2)) | msa2gwa_out & Identity(1) | Mapping((3, 0, 1, 2)) | agreq bgwa2msa.inverse = msa2bgwa bgwa2msa.name = "gwa_to_ifuslit" @@ -1021,6 +1023,8 @@ def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) ifupost.close() transform = Gwa2Slit(slits, slit_models) transform.name = "gwa_to_ifuslit" + transform.inputs = ('name', 'alpha', 'beta', 'gamma') + #transform.outputs = ('x_slit', 'y_slit', 'lam') return transform @@ -1103,11 +1107,14 @@ def gwa_to_slit(open_slits, input_model, disperser, # msa to before_gwa msa2bgwa = msa2gwa & Identity(1) | Mapping((3, 0, 1, 2)) | agreq bgwa2msa.inverse = msa2bgwa + bgwa2msa.name = "gwa_to_slit" + bgwa2msa.inputs = ('alpha', 'beta', 'gamma') + bgwa2msa.outputs = ('x_slit', 'y_slit', 'lam') slit_models.append(bgwa2msa) slits.append(slit) msa.close() transform = Gwa2Slit(slits, slit_models) - transform.name = "gwa_to_slit" + transform.inputs = ('name', 'alpha', 'beta', 'gamma') return transform @@ -1221,7 +1228,7 @@ def detector_to_gwa(reference_files, detector, disperser): ''' model = fpa | camera | u2dircos | rotation model.name = 'sca_to_gwa' - # input names already handled in stdatamodels Rotation3DToGWA + model.inputs = ('x_sca', 'y_sca') model.outputs = ('alpha', 'beta', 'gamma') return model @@ -1248,8 +1255,8 @@ def dms_to_sca(input_model): elif detector == 'NRS1': model = models.Identity(2) dms2sca = subarray2full | model - dms2sca.inputs = ('x', 'y') - dms2sca.outputs = ('x', 'y') + dms2sca.inputs = ('x_detector', 'y_detector') + dms2sca.outputs = ('x_sca', 'y_sca') dms2sca.name = 'dms_to_sca' return dms2sca @@ -1502,7 +1509,7 @@ def msa_to_oteip(reference_files): msa2fore_mapping.inverse = Identity(3) transform = msa2fore_mapping | (fore & wl_identity()) transform.name = "msa_to_oteip" - transform.inputs = ('x', 'y', 'lam') + transform.inputs = ('x_msa', 'y_msa', 'lam') transform.outputs = ('xan', 'yan', 'lam') return transform diff --git a/jwst/assign_wcs/tests/test_miri.py b/jwst/assign_wcs/tests/test_miri.py index f548a94ae4..18dc956061 100644 --- a/jwst/assign_wcs/tests/test_miri.py +++ b/jwst/assign_wcs/tests/test_miri.py @@ -190,7 +190,7 @@ def test_transform_metadata_mrs(): hdul = create_hdul(detector="MIRIFULONG", channel="34", band="MEDIUM") wcs = create_datamodel(hdul).meta.wcs - assert wcs.get_transform("detector", "alpha_beta").inputs == ('x', 'y') + assert wcs.get_transform("detector", "alpha_beta").inputs == ('x_direct', 'y_direct') assert wcs.get_transform("detector", "alpha_beta").outputs == ('alpha', 'beta', 'lam') assert wcs.get_transform("alpha_beta", "v2v3").inputs == ('alpha', 'beta', 'lam') assert wcs.get_transform("alpha_beta", "v2v3").outputs == ('v2', 'v3', 'lam') @@ -203,7 +203,7 @@ def test_transform_metadata_mrs(): def test_transform_metadata_lrs(create_hdul_lrs): wcs = create_datamodel(create_hdul_lrs).meta.wcs - assert wcs.get_transform("detector", "v2v3").inputs == ('x', 'y') + assert wcs.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct') assert wcs.get_transform("detector", "v2v3").outputs == ('v2', 'v3', 'lam') assert wcs.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3', 'lam') assert wcs.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3', 'lam') diff --git a/jwst/assign_wcs/tests/test_nircam.py b/jwst/assign_wcs/tests/test_nircam.py index b9c7349cdd..daf7bacb09 100644 --- a/jwst/assign_wcs/tests/test_nircam.py +++ b/jwst/assign_wcs/tests/test_nircam.py @@ -249,7 +249,7 @@ def test_wfss_sip(): def test_transform_metadata_imaging(create_imaging_wcs): wcsobj = create_imaging_wcs - assert wcsobj.get_transform("detector", "v2v3").inputs == ('x', 'y') + assert wcsobj.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct') assert wcsobj.get_transform("detector", "v2v3").outputs == ('v2', 'v3') assert wcsobj.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3') assert wcsobj.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3') @@ -261,10 +261,10 @@ def test_transform_metadata_imaging(create_imaging_wcs): def test_transform_metadata_grism(exptype): if exptype == "tso": wcsobj = create_tso_wcs() - assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x', 'y', 'order') + assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x_grism', 'y_grism', 'order') elif exptype == "wfss": wcsobj = create_wfss_wcs('GRISMR') - assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x', 'y', 'x0', 'y0', 'order') + assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x_grism', 'y_grism', 'x0', 'y0', 'order') assert wcsobj.get_transform("grism_detector", "detector").outputs == ('x_direct', 'y_direct', 'wavelength', 'order') assert wcsobj.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct', 'wavelength', 'order') diff --git a/jwst/assign_wcs/tests/test_niriss.py b/jwst/assign_wcs/tests/test_niriss.py index 752af25f25..997718abda 100644 --- a/jwst/assign_wcs/tests/test_niriss.py +++ b/jwst/assign_wcs/tests/test_niriss.py @@ -190,7 +190,7 @@ def test_wfss_sip(): def test_transform_metadata_imaging(): wcs = create_imaging_wcs('F200W') - assert wcs.get_transform("detector", "v2v3").inputs == ('x', 'y') + assert wcs.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct') assert wcs.get_transform("detector", "v2v3").outputs == ('v2', 'v3') assert wcs.get_transform("v2v3", "v2v3vacorr").inputs == ('v2', 'v3') assert wcs.get_transform("v2v3", "v2v3vacorr").outputs == ('v2', 'v3') @@ -201,7 +201,7 @@ def test_transform_metadata_imaging(): def test_transform_metadata_wfss(): wcs = create_wfss_wcs('GR150R') - assert wcs.get_transform("grism_detector", "detector").inputs == ('x', 'y', 'x0', 'y0', 'order') + assert wcs.get_transform("grism_detector", "detector").inputs == ('x_grism', 'y_grism', 'x0', 'y0', 'order') assert wcs.get_transform("grism_detector", "detector").outputs == ('x_direct', 'y_direct', 'wavelength', 'order') assert wcs.get_transform("detector", "v2v3").inputs == ('x_direct', 'y_direct', 'wavelength', 'order') assert wcs.get_transform("detector", "v2v3").outputs == ('v2', 'v3', 'wavelength', 'order') diff --git a/jwst/assign_wcs/tests/test_nirspec.py b/jwst/assign_wcs/tests/test_nirspec.py index 67972fc24d..cca9c73466 100644 --- a/jwst/assign_wcs/tests/test_nirspec.py +++ b/jwst/assign_wcs/tests/test_nirspec.py @@ -165,13 +165,13 @@ def test_nirspec_imaging(): # ensure transform metadata looks correct # available frames ['detector', 'sca', 'gwa', 'msa', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] # see technical report JWST-STScI-005921, SM-12 for details - assert w.get_transform('detector', 'sca').inputs == ('x', 'y') - assert w.get_transform('detector', 'sca').outputs == ('x', 'y') - assert w.get_transform('sca', 'gwa').inputs == ('x', 'y') - assert w.get_transform('sca', 'gwa').outputs == ('x', 'y', 'z') - assert w.get_transform('gwa', 'msa').inputs == ('x', 'y', 'z') - assert w.get_transform('gwa', 'msa').outputs == ('x', 'y', 'lam') - assert w.get_transform('msa', 'oteip').inputs == ('x', 'y', 'lam') + assert w.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') + assert w.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') + assert w.get_transform('sca', 'gwa').inputs == ('x_sca', 'y_sca') + assert w.get_transform('sca', 'gwa').outputs == ('alpha', 'beta', 'gamma') + assert w.get_transform('gwa', 'msa').inputs == ('alpha', 'beta', 'gamma') + assert w.get_transform('gwa', 'msa').outputs == ('x_msa', 'y_msa', 'lam') + assert w.get_transform('msa', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('msa', 'oteip').outputs == ('xan', 'yan') assert w.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan') assert w.get_transform('oteip', 'v2v3').outputs == ("v2", "v3") @@ -506,16 +506,14 @@ def test_functional_fs_msa(mode): slit_wcs = nirspec.nrs_wcs_set_input(im, 1) # add metadata tests for msa, which cover fs too because both call slitlets_wcs - assert w.get_transform('detector', 'sca').inputs == ('x', 'y') - assert w.get_transform('detector', 'sca').outputs == ('x', 'y') - assert w.get_transform('sca', 'gwa').inputs == ("?", "?", 'x', 'y') # I think these are x,y of the slit center - assert w.get_transform('sca', 'gwa').outputs == ("?", "?", 'alpha', 'beta', 'gamma') + assert w.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') + assert w.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') + assert w.get_transform('sca', 'gwa').inputs == ("null", "null", 'x_sca', 'y_sca') + assert w.get_transform('sca', 'gwa').outputs == ("null", "null", 'alpha', 'beta', 'gamma') assert w.get_transform('gwa', 'slit_frame').inputs == ('name', 'alpha', 'beta', 'gamma') assert w.get_transform('gwa', 'slit_frame').outputs == ('name', 'x_slit', 'y_slit', 'lam') - assert w.get_transform('slit_frame', 'slicer').inputs == ('x_slit', 'y_slit', 'lam') - assert w.get_transform('slit_frame', 'slicer').outputs == ('x_slicer', 'y_slicer', 'lam') - assert w.get_transform('slicer', 'msa_frame').inputs == ('x_slicer', 'y_slicer', 'lam') - assert w.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') + assert w.get_transform('slit_frame', 'msa_frame').inputs == ('name', 'x_slit', 'y_slit') + assert w.get_transform('slit_frame', 'msa_frame').outputs == ('x_msa', 'y_msa') assert w.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') assert w.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan', 'lam') @@ -789,15 +787,15 @@ def test_functional_ifu_grating(wcs_ifu_grating): # test transform metadata for slit_wcs # available frames ['detector', 'sca', 'gwa', 'slit_frame', 'slicer', 'msa_frame', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] - assert slit_wcs.get_transform('detector', 'sca').inputs == ('x', 'y') - assert slit_wcs.get_transform('detector', 'sca').outputs == ('x', 'y') - assert slit_wcs.get_transform('sca', 'gwa').inputs == ('x', 'y') + assert slit_wcs.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') + assert slit_wcs.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') + assert slit_wcs.get_transform('sca', 'gwa').inputs == ('x_sca', 'y_sca') assert slit_wcs.get_transform('sca', 'gwa').outputs == ('alpha', 'beta', 'gamma') assert slit_wcs.get_transform('gwa', 'slit_frame').inputs == ('alpha', 'beta', 'gamma') assert slit_wcs.get_transform('gwa', 'slit_frame').outputs == ('x_slit', 'y_slit', 'lam') - assert slit_wcs.get_transform('slit_frame', 'slicer').inputs == ('x_slit', 'y_slit', 'lam') - assert slit_wcs.get_transform('slit_frame', 'slicer').outputs == ('x_slicer', 'y_slicer', 'lam') - assert slit_wcs.get_transform('slicer', 'msa_frame').inputs == ('x_slicer', 'y_slicer', 'lam') + assert slit_wcs.get_transform('slit_frame', 'slicer').inputs == ('name', 'x_slit', 'y_slit') + assert slit_wcs.get_transform('slit_frame', 'slicer').outputs == ('name', 'x_slicer', 'y_slicer') + assert slit_wcs.get_transform('slicer', 'msa_frame').inputs == ('name', 'x_slicer', 'y_slicer') assert slit_wcs.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') assert slit_wcs.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') assert slit_wcs.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') @@ -927,16 +925,15 @@ def test_functional_ifu_prism(): # test transform metadata for wcs # available frames ['detector', 'sca', 'gwa', 'slit_frame', 'slicer', 'msa_frame', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] # read-only - must extract single slit id to execute the wcs - breakpoint() assert w.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') assert w.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') assert w.get_transform('sca', 'gwa').inputs == ('null', 'null', 'x_sca', 'y_sca') assert w.get_transform('sca', 'gwa').outputs == ('null', 'null', 'alpha', 'beta', 'gamma') assert w.get_transform('gwa', 'slit_frame').inputs == ('name', 'alpha', 'beta', 'gamma') assert w.get_transform('gwa', 'slit_frame').outputs == ('name', 'x_slit', 'y_slit', 'lam') - assert w.get_transform('slit_frame', 'slicer').inputs == ('x_slit', 'y_slit', 'lam') - assert w.get_transform('slit_frame', 'slicer').outputs == ('x_slicer', 'y_slicer', 'lam') - assert w.get_transform('slicer', 'msa_frame').inputs == ('x_slicer', 'y_slicer', 'lam') + assert w.get_transform('slit_frame', 'slicer').inputs == ('name', 'x_slit', 'y_slit') + assert w.get_transform('slit_frame', 'slicer').outputs == ('x_msa', 'y_msa') + assert w.get_transform('slicer', 'msa_frame').inputs == ('x_msa', 'y_msa') assert w.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') From 6d077e66a4aee040cd8f299ac4c2236874fd790b Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Thu, 30 May 2024 16:10:24 -0400 Subject: [PATCH 4/4] first draft of fully labeled inputs and outputs --- jwst/assign_wcs/nirspec.py | 6 ++-- jwst/assign_wcs/tests/test_nircam.py | 4 +-- jwst/assign_wcs/tests/test_nirspec.py | 43 +++++++++++---------------- 3 files changed, 22 insertions(+), 31 deletions(-) diff --git a/jwst/assign_wcs/nirspec.py b/jwst/assign_wcs/nirspec.py index ed2fc6bdf9..3d4a760d26 100644 --- a/jwst/assign_wcs/nirspec.py +++ b/jwst/assign_wcs/nirspec.py @@ -283,7 +283,7 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): # in the whole pipeline) to microns (which is the expected output) # # "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "world" - breakpoint() + pipeline = [(det, dms2detector), (sca, det2gwa), (gwa, gwa2slit), @@ -854,7 +854,7 @@ def ifuslit_to_slicer(slits, reference_files): msa_transform = slitdata_model | ifuslicer_model msa_transform.name = "ifuslit_to_slicer" msa_transform.inputs = ('x_slit', 'y_slit') - msa_transform.outputs = ('x_slicer', 'y_slicer') + msa_transform.outputs = ('x_msa', 'y_msa') models.append(msa_transform) ifuslicer.close() @@ -878,7 +878,7 @@ def slicer_to_msa(reference_files): ifufore2fore_mapping.inverse = Mapping((0, 1, 2, 2)) ifu_fore_transform = slicer2fore_mapping | ifufore & Identity(1) ifu_fore_transform.name = "slicer_to_msa" - ifu_fore_transform.inputs = ('x_slicer', 'y_slicer', 'lam') + ifu_fore_transform.inputs = ('x_msa', 'y_msa', 'lam') ifu_fore_transform.outputs = ('x_msa', 'y_msa', 'lam') return ifu_fore_transform diff --git a/jwst/assign_wcs/tests/test_nircam.py b/jwst/assign_wcs/tests/test_nircam.py index daf7bacb09..07b63ea52c 100644 --- a/jwst/assign_wcs/tests/test_nircam.py +++ b/jwst/assign_wcs/tests/test_nircam.py @@ -258,9 +258,9 @@ def test_transform_metadata_imaging(create_imaging_wcs): @pytest.mark.parametrize('exptype', ['tso', 'wfss']) -def test_transform_metadata_grism(exptype): +def test_transform_metadata_grism(exptype, create_tso_wcs): if exptype == "tso": - wcsobj = create_tso_wcs() + wcsobj = create_tso_wcs assert wcsobj.get_transform("grism_detector", "detector").inputs == ('x_grism', 'y_grism', 'order') elif exptype == "wfss": wcsobj = create_wfss_wcs('GRISMR') diff --git a/jwst/assign_wcs/tests/test_nirspec.py b/jwst/assign_wcs/tests/test_nirspec.py index cca9c73466..d97071ac2c 100644 --- a/jwst/assign_wcs/tests/test_nirspec.py +++ b/jwst/assign_wcs/tests/test_nirspec.py @@ -5,6 +5,7 @@ from math import cos, sin import os.path +import asdf import pytest import numpy as np from numpy.testing import assert_allclose @@ -785,30 +786,8 @@ def test_functional_ifu_grating(wcs_ifu_grating): assert_allclose(v2, ins_tab['xV2V3']) assert_allclose(v3, ins_tab['yV2V3']) - # test transform metadata for slit_wcs - # available frames ['detector', 'sca', 'gwa', 'slit_frame', 'slicer', 'msa_frame', 'oteip', 'v2v3', 'v2v3vacorr', 'world'] - assert slit_wcs.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') - assert slit_wcs.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') - assert slit_wcs.get_transform('sca', 'gwa').inputs == ('x_sca', 'y_sca') - assert slit_wcs.get_transform('sca', 'gwa').outputs == ('alpha', 'beta', 'gamma') - assert slit_wcs.get_transform('gwa', 'slit_frame').inputs == ('alpha', 'beta', 'gamma') - assert slit_wcs.get_transform('gwa', 'slit_frame').outputs == ('x_slit', 'y_slit', 'lam') - assert slit_wcs.get_transform('slit_frame', 'slicer').inputs == ('name', 'x_slit', 'y_slit') - assert slit_wcs.get_transform('slit_frame', 'slicer').outputs == ('name', 'x_slicer', 'y_slicer') - assert slit_wcs.get_transform('slicer', 'msa_frame').inputs == ('name', 'x_slicer', 'y_slicer') - assert slit_wcs.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') - assert slit_wcs.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') - assert slit_wcs.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') - assert slit_wcs.get_transform('oteip', 'v2v3').inputs == ('xan', 'yan', 'lam') - assert slit_wcs.get_transform('oteip', 'v2v3').outputs == ("v2", "v3", 'lam') - assert slit_wcs.get_transform('v2v3', 'v2v3vacorr').inputs == ("v2", "v3", 'lam') - assert slit_wcs.get_transform('v2v3', 'v2v3vacorr').outputs == ("v2", "v3", 'lam') - assert slit_wcs.get_transform('v2v3vacorr', 'world').inputs == ("v2", "v3", 'lam') - assert slit_wcs.get_transform('v2v3vacorr', 'world').outputs == ("ra", "dec", 'lam') - - - -def test_functional_ifu_prism(): + +def test_functional_ifu_prism(tmp_cwd): """Compare Nirspec instrument model with IDT model for IFU prism.""" # setup test model_file = 'ifu_prism_functional_ESA_v1_20180619.txt' @@ -933,7 +912,7 @@ def test_functional_ifu_prism(): assert w.get_transform('gwa', 'slit_frame').outputs == ('name', 'x_slit', 'y_slit', 'lam') assert w.get_transform('slit_frame', 'slicer').inputs == ('name', 'x_slit', 'y_slit') assert w.get_transform('slit_frame', 'slicer').outputs == ('x_msa', 'y_msa') - assert w.get_transform('slicer', 'msa_frame').inputs == ('x_msa', 'y_msa') + assert w.get_transform('slicer', 'msa_frame').inputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('slicer', 'msa_frame').outputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('msa_frame', 'oteip').inputs == ('x_msa', 'y_msa', 'lam') assert w.get_transform('msa_frame', 'oteip').outputs == ('xan', 'yan', 'lam') @@ -944,7 +923,19 @@ def test_functional_ifu_prism(): assert w.get_transform('v2v3vacorr', 'world').inputs == ("v2", "v3", 'lam') assert w.get_transform('v2v3vacorr', 'world').outputs == ("ra", "dec", 'lam') - # can they be serialized? when this is done are the names restored? + + # test these can be serialized and that names are restored when this is done + outfile = 'w_asdf.asdf' + tree = {"wcs": w} + af = asdf.AsdfFile(tree) + af.write_to(outfile) + with asdf.open(outfile) as af: + w = af.tree['wcs'] + print(w) + assert True + assert w.get_transform('detector', 'sca').inputs == ('x_detector', 'y_detector') + assert w.get_transform('detector', 'sca').outputs == ('x_sca', 'y_sca') + assert w.get_transform('detector', 'sca').name == "dms_to_sca" def test_ifu_bbox():