Skip to content

Commit

Permalink
CADC-10810 - integration with server-side and TAOSII Temporal WCS cha…
Browse files Browse the repository at this point in the history
…nges.
  • Loading branch information
SharonGoliath committed Jul 19, 2023
1 parent 17e511f commit acdfc72
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 110 deletions.
209 changes: 103 additions & 106 deletions caom2utils/caom2utils/caom2blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from logging.handlers import TimedRotatingFileHandler

import math
from astropy.wcs import utils, Wcsprm, WCS
from astropy.wcs import SingularMatrixError, utils, Wcsprm, WCS
from astropy.io import fits
from astropy.time import Time
from cadcutils import version
Expand Down Expand Up @@ -2849,7 +2849,6 @@ def _get_telescope(self, current):
name = self._get_from_list(
'Observation.telescope.name', index=0,
current=None if current is None else current.name)
logging.error(self._get_from_list('Observation.telescope.geoLocationX', index=0))
geo_x = _to_float(
self._get_from_list(
'Observation.telescope.geoLocationX', index=0,
Expand Down Expand Up @@ -3953,19 +3952,18 @@ def augment_custom(self, chunk):
self.logger.debug('No WCS Custom axis.function')
return

chunk.custom_axis = custom_axis_index + 1

naxis = CoordAxis1D(self._get_axis(custom_axis_index))
if self.wcs.has_cd():
delta = self.wcs.cd[custom_axis_index][
custom_axis_index]
else:
delta = self.wcs.cdelt[custom_axis_index]
naxis.function = CoordFunction1D(custom_axis_length, delta, self._get_ref_coord(custom_axis_index))
if not chunk.custom:
chunk.custom = CustomWCS(naxis)
else:
chunk.custom.axis = naxis
if custom_axis_length > 0:
chunk.custom_axis = custom_axis_index + 1
naxis = CoordAxis1D(self._get_axis(custom_axis_index))
if self.wcs.has_cd():
delta = self.wcs.cd[custom_axis_index][custom_axis_index]
else:
delta = self.wcs.cdelt[custom_axis_index]
naxis.function = CoordFunction1D(custom_axis_length, delta, self._get_ref_coord(custom_axis_index))
if not chunk.custom:
chunk.custom = CustomWCS(naxis)
else:
chunk.custom.axis = naxis

self.logger.debug('End Custom WCS augmentation.')

Expand All @@ -3990,32 +3988,33 @@ def augment_energy(self, chunk):
self.logger.debug('No WCS Energy axis.function')
return

chunk.energy_axis = energy_axis_index + 1
naxis = CoordAxis1D(self._get_axis(energy_axis_index))
naxis.error = self._get_coord_error(energy_axis_index)
if self.wcs.has_cd():
delta = self.wcs.cd[energy_axis_index][energy_axis_index]
else:
delta = self.wcs.cdelt[energy_axis_index]
naxis.function = CoordFunction1D(energy_axis_length, delta, self._get_ref_coord(energy_axis_index))
if energy_axis_length > 0:
chunk.energy_axis = energy_axis_index + 1
naxis = CoordAxis1D(self._get_axis(energy_axis_index))
naxis.error = self._get_coord_error(energy_axis_index)
if self.wcs.has_cd():
delta = self.wcs.cd[energy_axis_index][energy_axis_index]
else:
delta = self.wcs.cdelt[energy_axis_index]
naxis.function = CoordFunction1D(energy_axis_length, delta, self._get_ref_coord(energy_axis_index))

specsys = _to_str(self.wcs.specsys)
if not chunk.energy:
chunk.energy = SpectralWCS(naxis, specsys)
else:
chunk.energy.axis = naxis
chunk.energy.specsys = specsys

chunk.energy.ssysobs = _to_str(self._sanitize(self.wcs.ssysobs))
# wcs returns 0.0 by default
if self._sanitize(self.wcs.restfrq) != 0:
chunk.energy.restfrq = self._sanitize(self.wcs.restfrq)
if self._sanitize(self.wcs.restwav) != 0:
chunk.energy.restwav = self._sanitize(self.wcs.restwav)
chunk.energy.velosys = self._sanitize(self.wcs.velosys)
chunk.energy.zsource = self._sanitize(self.wcs.zsource)
chunk.energy.ssyssrc = _to_str(self._sanitize(self.wcs.ssyssrc))
chunk.energy.velang = self._sanitize(self.wcs.velangl)
specsys = _to_str(self.wcs.specsys)
if not chunk.energy:
chunk.energy = SpectralWCS(naxis, specsys)
else:
chunk.energy.axis = naxis
chunk.energy.specsys = specsys

chunk.energy.ssysobs = _to_str(self._sanitize(self.wcs.ssysobs))
# wcs returns 0.0 by default
if self._sanitize(self.wcs.restfrq) != 0:
chunk.energy.restfrq = self._sanitize(self.wcs.restfrq)
if self._sanitize(self.wcs.restwav) != 0:
chunk.energy.restwav = self._sanitize(self.wcs.restwav)
chunk.energy.velosys = self._sanitize(self.wcs.velosys)
chunk.energy.zsource = self._sanitize(self.wcs.zsource)
chunk.energy.ssyssrc = _to_str(self._sanitize(self.wcs.ssyssrc))
chunk.energy.velang = self._sanitize(self.wcs.velangl)
self.logger.debug('End Energy WCS augmentation.')

def augment_position(self, chunk):
Expand Down Expand Up @@ -4076,29 +4075,29 @@ def augment_temporal(self, chunk):
# set chunk.time
self.logger.debug('Begin temporal axis augmentation.')

aug_naxis = self._get_axis(time_axis_index)
aug_error = self._get_coord_error(time_axis_index)
aug_ref_coord = self._get_ref_coord(time_axis_index)
if self.wcs.has_cd():
delta = self.wcs.cd[time_axis_index][time_axis_index]
else:
delta = self.wcs.cdelt[time_axis_index]

try:
axis_length = self._get_axis_length(time_axis_index + 1)
except ValueError:
self.logger.debug('No WCS Temporal axis.function')
return

if aug_ref_coord is not None and axis_length is not None:
aug_function = CoordFunction1D(axis_length, delta, aug_ref_coord)
naxis = CoordAxis1D(aug_naxis, aug_error, None, None, aug_function)
if not chunk.time:
chunk.time = TemporalWCS(naxis)
if axis_length > 0:
aug_naxis = self._get_axis(time_axis_index)
aug_error = self._get_coord_error(time_axis_index)
aug_ref_coord = self._get_ref_coord(time_axis_index)
if self.wcs.has_cd():
delta = self.wcs.cd[time_axis_index][time_axis_index]
else:
chunk.time.axis = naxis
delta = self.wcs.cdelt[time_axis_index]
if aug_ref_coord is not None and axis_length is not None:
aug_function = CoordFunction1D(axis_length, delta, aug_ref_coord)
naxis = CoordAxis1D(aug_naxis, aug_error, None, None, aug_function)
if not chunk.time:
chunk.time = TemporalWCS(naxis)
else:
chunk.time.axis = naxis

self._finish_chunk_time(chunk)
self._finish_chunk_time(chunk)
self.logger.debug('End TemporalWCS augmentation.')

def augment_polarization(self, chunk):
Expand All @@ -4116,26 +4115,25 @@ def augment_polarization(self, chunk):
self.logger.debug('No WCS Polarization info')
return

chunk.polarization_axis = polarization_axis_index + 1

naxis = CoordAxis1D(self._get_axis(polarization_axis_index))
if self.wcs.has_cd():
delta = self.wcs.cd[polarization_axis_index][
polarization_axis_index]
else:
delta = self.wcs.cdelt[polarization_axis_index]

try:
axis_length = self._get_axis_length(polarization_axis_index + 1)
except ValueError:
self.logger.debug('No WCS Polarization axis.function')
return

naxis.function = CoordFunction1D(axis_length, delta, self._get_ref_coord(polarization_axis_index))
if not chunk.polarization:
chunk.polarization = PolarizationWCS(naxis)
else:
chunk.polarization.axis = naxis
if axis_length > 0:
chunk.polarization_axis = polarization_axis_index + 1

naxis = CoordAxis1D(self._get_axis(polarization_axis_index))
if self.wcs.has_cd():
delta = self.wcs.cd[polarization_axis_index][polarization_axis_index]
else:
delta = self.wcs.cdelt[polarization_axis_index]
naxis.function = CoordFunction1D(axis_length, delta, self._get_ref_coord(polarization_axis_index))
if not chunk.polarization:
chunk.polarization = PolarizationWCS(naxis)
else:
chunk.polarization.axis = naxis

self.logger.debug('End Polarization WCS augmentation.')

Expand All @@ -4160,10 +4158,7 @@ def augment_observable(self, chunk):
self.logger.debug('End Observable WCS augmentation.')

def _finish_chunk_position(self, chunk):
if chunk.position.resolution is None:
# JJK 30-01-23
# In a spatial data chunk the resolution is 2 times the pixel size. We can get the pixel size from the wcs
chunk.position.resolution = utils.proj_plane_pixel_scales(self.wcs)[0]
pass

def _finish_chunk_time(self, chunk):
raise NotImplementedError
Expand Down Expand Up @@ -4241,11 +4236,13 @@ def _get_dimension(self, xindex, yindex):
except ValueError:
self.logger.debug('No WCS Energy axis.function')
return None
aug_dim1 = _to_int(xindex_axis_length)
aug_dim2 = _to_int(yindex_axis_length)
if aug_dim1 and aug_dim2:
aug_dimension = Dimension2D(aug_dim1, aug_dim2)
self.logger.debug('End 2D dimension augmentation.')

if xindex_axis_length > 0 and yindex_axis_length > 0:
aug_dim1 = _to_int(xindex_axis_length)
aug_dim2 = _to_int(yindex_axis_length)
if aug_dim1 and aug_dim2:
aug_dimension = Dimension2D(aug_dim1, aug_dim2)
self.logger.debug('End 2D dimension augmentation.')
return aug_dimension

def _get_position_axis(self):
Expand Down Expand Up @@ -4445,8 +4442,7 @@ def _get_axis_index(self, keywords):

def _get_axis_length(self, for_axis):
if self._wcs.array_shape is None:
# TODO I think this is wrong
return 1
return 0
else:
if len(self._wcs.array_shape) == 1:
result = self._wcs.array_shape[0]
Expand Down Expand Up @@ -4474,6 +4470,14 @@ def assign_sanitize(self, assignee, index, key, sanitize=True):
if x is not None and not ObsBlueprint.needs_lookup(x):
assignee[index] = x

def _assign_cd(self, key, cd, count):
x = self._blueprint._get(key, self._extension)
if x is not None:
if ObsBlueprint.needs_lookup(x):
cd[count][count] = 1.0
else:
cd[count][count] = x

def _set_wcs(self):
self._wcs = WCS(naxis=self._blueprint.get_configed_axes_count())
array_shape = [0] * self._blueprint.get_configed_axes_count()
Expand All @@ -4491,15 +4495,12 @@ def _set_wcs(self):
self._axes['dec'][1] = True
self._axes['ra'][0] = count
self._axes['dec'][0] = count + 1
# temp = [0] * self._blueprint.get_configed_axes_count()
# cd = [temp.copy()
# for _ in range(self._blueprint.get_configed_axes_count())]
self.assign_sanitize(ctype, count, 'Chunk.position.axis.axis1.ctype')
self.assign_sanitize(ctype, count + 1, 'Chunk.position.axis.axis2.ctype')
self.assign_sanitize(cunit, count, 'Chunk.position.axis.axis1.cunit')
self.assign_sanitize(cunit, count + 1, 'Chunk.position.axis.axis2.cunit')
array_shape[count] = self._blueprint._get('Chunk.position.axis.function.dimension.naxis1')
array_shape[count + 1] = self._blueprint._get('Chunk.position.axis.function.dimension.naxis2')
self.assign_sanitize(array_shape, count, 'Chunk.position.axis.function.dimension.naxis1')
self.assign_sanitize(array_shape, count + 1, 'Chunk.position.axis.function.dimension.naxis2')
self.assign_sanitize(crpix, count, 'Chunk.position.axis.function.refCoord.coord1.pix')
self.assign_sanitize(crpix, count + 1, 'Chunk.position.axis.function.refCoord.coord2.pix')
self.assign_sanitize(crval, count, 'Chunk.position.axis.function.refCoord.coord1.val')
Expand Down Expand Up @@ -4530,39 +4531,35 @@ def _set_wcs(self):
self._axes['time'][0] = count
self.assign_sanitize(ctype, count, 'Chunk.time.axis.axis.ctype', False)
self.assign_sanitize(cunit, count, 'Chunk.time.axis.axis.cunit', False)
array_shape[count] = self._blueprint._get(
'Chunk.time.axis.function.naxis', self._extension)
self.assign_sanitize(array_shape, count, 'Chunk.time.axis.function.naxis', False)
self.assign_sanitize(crpix, count, 'Chunk.time.axis.function.refCoord.pix', False)
self.assign_sanitize(crval, count, 'Chunk.time.axis.function.refCoord.val', False)
self.assign_sanitize(crder, count, 'Chunk.time.axis.error.rnder')
self.assign_sanitize(csyer, count, 'Chunk.time.axis.error.syser')
cd[count][count] = 1.0
self._assign_cd('Chunk.time.axis.function.delta', cd, count)
count += 1
if self._blueprint._energy_axis_configed:
self._axes['energy'][1] = True
self._axes['energy'][0] = count
self.assign_sanitize(ctype, count, 'Chunk.energy.axis.axis.ctype', False)
self.assign_sanitize(cunit, count, 'Chunk.energy.axis.axis.cunit', False)
array_shape[count] = self._blueprint._get(
'Chunk.energy.axis.function.naxis', self._extension)
self.assign_sanitize(array_shape, count, 'Chunk.energy.axis.function.naxis', False)
self.assign_sanitize(crpix, count, 'Chunk.energy.axis.function.refCoord.pix', False)
self.assign_sanitize(crval, count, 'Chunk.energy.axis.function.refCoord.val', False)
self.assign_sanitize(crder, count, 'Chunk.energy.axis.error.rnder')
self.assign_sanitize(csyer, count, 'Chunk.energy.axis.error.syser')
cd[count][count] = 1.0
self._assign_cd('Chunk.energy.axis.function.delta', cd, count)
count += 1
if self._blueprint._polarization_axis_configed:
self._axes['polarization'][1] = True
self._axes['polarization'][0] = count
self.assign_sanitize(ctype, count, 'Chunk.polarization.axis.axis.ctype', False)
self.assign_sanitize(cunit, count, 'Chunk.polarization.axis.axis.cunit', False)
array_shape[count] = self._blueprint._get(
'Chunk.polarization.axis.function.naxis', self._extension)
self.assign_sanitize(array_shape, count, 'Chunk.polarization.axis.function.naxis', False)
self.assign_sanitize(crpix, count, 'Chunk.polarization.axis.function.refCoord.pix', False)
self.assign_sanitize(crval, count, 'Chunk.polarization.axis.function.refCoord.val', False)
cd[count][count] = 1.0
self._assign_cd('Chunk.polarization.axis.function.delta', cd, count)
count += 1
# TODO - where's the delta?
if self._blueprint._obs_axis_configed:
self._axes['observable'][1] = True
self._axes['observable'][0] = count
Expand All @@ -4578,27 +4575,21 @@ def _set_wcs(self):
self._axes['custom'][0] = count
self.assign_sanitize(ctype, count, 'Chunk.custom.axis.axis.ctype', False)
self.assign_sanitize(cunit, count, 'Chunk.custom.axis.axis.cunit', False)
array_shape[count] = self._blueprint._get(
'Chunk.custom.axis.function.naxis', self._extension)
# TODO delta
self.assign_sanitize(array_shape, count, 'Chunk.custom.axis.function.naxis', False)
self.assign_sanitize(crpix, count, 'Chunk.custom.axis.function.refCoord.pix', False)
self.assign_sanitize(crval, count, 'Chunk.custom.axis.function.refCoord.val', False)
cd[count][count] = 1.0
self._assign_cd('Chunk.custom.axis.function.delta', cd, count)
count += 1

if not all(val == 0 for val in array_shape):
self._wcs.array_shape = array_shape
if not all(val == 0 for val in cunit):
# logging.error(cunit)
self._wcs.wcs.cunit = cunit
if not all(val == 0 for val in ctype):
# logging.error(ctype)
self._wcs.wcs.ctype = ctype
if not all(val == 0 for val in crpix):
# logging.error(crpix)
self._wcs.wcs.crpix = crpix
if not all(val == 0 for val in crval):
# logging.error(crval)
self._wcs.wcs.crval = crval
if not all(val == 0 for val in crder):
self._wcs.wcs.crder = crder
Expand All @@ -4619,8 +4610,14 @@ def _finish_chunk_observable(self, chunk):

def _finish_chunk_position(self, chunk):
if chunk.position.resolution is None:
temp = utils.proj_plane_pixel_scales(self._wcs)
chunk.position.resolution = temp[0]
try:
# JJK 30-01-23
# In a spatial data chunk the resolution is 2 times the pixel size. We can get the pixel size from the wcs
temp = utils.proj_plane_pixel_scales(self._wcs)
chunk.position.resolution = temp[0]
except SingularMatrixError as e:
# cannot calculate position.resolution, ignore and continue on
self.logger.warning(f'Not calculating resolution due to {e}')

def _finish_chunk_time(self, chunk):
if not math.isnan(self._wcs.wcs.xposure):
Expand All @@ -4629,7 +4626,7 @@ def _finish_chunk_time(self, chunk):
chunk.time.timesys = self._wcs.wcs.timesys
if self._wcs.wcs.trefpos is not None and self._wcs.wcs.trefpos != '':
chunk.time.trefpos = self._wcs.wcs.trefpos
if self._wcs.wcs.mjdref is not None and self._wcs.wcs.mjdref[0] != '':
if self._wcs.wcs.mjdref is not None and self._wcs.wcs.mjdref[0] != '' and self._wcs.wcs.mjdref[0] != 0.0:
# the astropy value is an array of length 2, use the first value
chunk.time.mjdref = self._wcs.wcs.mjdref[0]

Expand Down
Loading

0 comments on commit acdfc72

Please sign in to comment.