Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 48 additions & 18 deletions python/lsst/ip/diffim/getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import lsst.geom as geom
import lsst.afw.geom as afwGeom
import lsst.afw.table as afwTable
from lsst.afw.math._warper import computeWarpedBBox
import lsst.afw.math as afwMath
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
Expand Down Expand Up @@ -277,25 +278,39 @@ def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter):
for tract in coaddExposureHandles:
maskedImages, catalog, totalBox = self._makeExposureCatalog(coaddExposureHandles[tract],
dataIds[tract])
warpedBox = computeWarpedBBox(catalog[0].wcs, bbox, wcs)
warpedBox.grow(5) # to ensure we catch all relevant input pixels
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding of computeWarpedBBox is that that already errs on the side of a larger box, so growing it here is probably unnecessary. That said, an extra 5 pixels shouldn't make much of a difference, and it will still be fewer pixels than before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that without growing the box by a few pixels, the corners in the tests weren't correct, so I decided to be a bit generous here.

# Combine images from individual patches together.
unwarped = self._merge(maskedImages, totalBox, catalog[0].wcs)
unwarped, count, included = self._merge(maskedImages, warpedBox, catalog[0].wcs)
# Delete `maskedImages` after combining into one large image to reduce peak memory use
del maskedImages
potentialInput = self.warper.warpExposure(wcs, unwarped, destBBox=bbox)
if count == 0:
self.log.info("No valid pixels from coadd patches in tract %s; not including in output.",
tract)
continue
warpedBox.clip(totalBox)
potentialInput = self.warper.warpExposure(wcs, unwarped.subset(warpedBox), destBBox=bbox)

# Delete the single large `unwarped` image after warping to reduce peak memory use
del unwarped
if not np.any(np.isfinite(potentialInput.image.array)):
self.log.info("No overlap from coadds in tract %s; not including in output.", tract)
if np.all(potentialInput.mask.array & potentialInput.mask.getPlaneBitMask("NO_DATA")):
self.log.info("No overlap from coadd patches in tract %s; not including in output.", tract)
continue

catalogs.append(catalog)
warped[tract] = potentialInput
warped[tract].setWcs(wcs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this line unnecessary, because it's set in _merge?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the merged result has the wcs when it is created.

# Trim the exposure catalog to just the patches that were used.
tempCatalog = afwTable.ExposureCatalog(self.schema)
tempCatalog.reserve(len(included))
for i in included:
tempCatalog.append(catalog[i])
catalogs.append(tempCatalog)
warped[tract] = potentialInput.maskedImage

if len(warped) == 0:
raise pipeBase.NoWorkFound("No patches found to overlap science exposure.")
template = self._merge([x.maskedImage for x in warped.values()], bbox, wcs)
# At this point, all entries will be valid, so we can ignore included.
template, count, _ = self._merge(warped, bbox, wcs)
if count == 0:
raise pipeBase.NoWorkFound("No valid pixels in warped template.")

# Make a single catalog containing all the inputs that were accepted.
catalog = afwTable.ExposureCatalog(self.schema)
Expand Down Expand Up @@ -362,7 +377,7 @@ def _makeExposureCatalog(self, exposureRefs, dataIds):

Returns
-------
images : `list` [`lsst.afw.image.MaskedImage`]
images : `dict` [`lsst.afw.image.MaskedImage`]
MaskedImages of each of the input exposures, for warping.
catalog : `lsst.afw.table.ExposureCatalog`
Catalog of metadata for each exposure
Expand All @@ -372,11 +387,11 @@ def _makeExposureCatalog(self, exposureRefs, dataIds):
catalog = afwTable.ExposureCatalog(self.schema)
catalog.reserve(len(exposureRefs))
exposures = (exposureRef.get() for exposureRef in exposureRefs)
images = []
images = {}
totalBox = geom.Box2I()

for coadd, dataId in zip(exposures, dataIds):
images.append(coadd.maskedImage)
images[dataId] = coadd.maskedImage
bbox = coadd.getBBox()
totalBox = totalBox.expandedTo(bbox)
record = catalog.addNew()
Expand All @@ -393,15 +408,15 @@ def _makeExposureCatalog(self, exposureRefs, dataIds):

return images, catalog, totalBox

@staticmethod
def _merge(maskedImages, bbox, wcs):
def _merge(self, maskedImages, bbox, wcs):
"""Merge the images that came from one tract into one larger image,
ignoring NaN pixels and non-finite variance pixels from individual
exposures.

Parameters
----------
maskedImages : `list` [`lsst.afw.image.MaskedImage`]
maskedImages : `dict` [`lsst.afw.image.MaskedImage` or
`lsst.afw.image.Exposure`]
Images to be merged into one larger bounding box.
bbox : `lsst.geom.Box2I`
Bounding box defining the image to merge into.
Expand All @@ -413,10 +428,24 @@ def _merge(maskedImages, bbox, wcs):
merged : `lsst.afw.image.MaskedImage`
Merged image with all of the inputs at their respective bbox
positions.
count : `int`
Count of the number of good pixels (those with positive weights)
in the merged image.
included : `list` [`int`]
List of indexes of patches that were included in the merged
result, to be used to trim the exposure catalog.
"""
merged = afwImage.ExposureF(bbox, wcs)
weights = afwImage.ImageF(bbox)
for maskedImage in maskedImages:
included = [] # which patches were included in the result
for i, (dataId, maskedImage) in enumerate(maskedImages.items()):
# Only merge into the trimmed box, to save memory
clippedBox = geom.Box2I(maskedImage.getBBox())
clippedBox.clip(bbox)
if clippedBox.area == 0:
self.log.debug("%s does not overlap template region.", dataId)
continue # nothing in this image overlaps the output
maskedImage = maskedImage.subset(clippedBox)
# Catch both zero-value and NaN variance plane pixels
good = (maskedImage.variance.array > 0) & (np.isfinite(maskedImage.variance.array))
weight = maskedImage.variance.array[good]**(-0.5)
Expand All @@ -432,10 +461,11 @@ def _merge(maskedImages, bbox, wcs):
# `weight` are the exact values we want to scale by.
maskedImage.image.array[good] *= weight
maskedImage.variance.array[good] *= weight
weights[maskedImage.getBBox()].array[good] += weight
weights[clippedBox].array[good] += weight
# Free memory before creating new large arrays
del weight
merged.maskedImage[maskedImage.getBBox()] += maskedImage
merged.maskedImage[clippedBox] += maskedImage
included.append(i)

good = weights.array > 0

Expand All @@ -448,7 +478,7 @@ def _merge(maskedImages, bbox, wcs):

merged.mask.array[~good] |= merged.mask.getPlaneBitMask("NO_DATA")

return merged
return merged, good.sum(), included

def _makePsf(self, template, catalog, wcs):
"""Return a PSF containing the PSF at each of the input regions.
Expand Down
28 changes: 18 additions & 10 deletions tests/test_getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import lsst.afw.geom
import lsst.afw.image
import lsst.afw.math
from lsst.daf.butler import DataCoordinate, DimensionUniverse
import lsst.geom
import lsst.ip.diffim
import lsst.meas.algorithms
Expand Down Expand Up @@ -168,11 +169,14 @@ def _makePatches(self, tract):
)
)
self.patches[tract.tract_id].append(dataRef)
self.dataIds[tract.tract_id].append({"tract": tract.tract_id,
"patch": patchId,
"band": "a"})

def _checkMetadata(self, template, config, box, wcs, nInputs):
dataCoordinate = DataCoordinate.standardize({"tract": tract.tract_id,
"patch": patchId,
"band": "a",
"skymap": "skymap"},
universe=DimensionUniverse())
self.dataIds[tract.tract_id].append(dataCoordinate)

def _checkMetadata(self, template, config, box, wcs, nPsfs):
"""Check that the various metadata components were set correctly.
"""
expectedBox = lsst.geom.Box2I(box)
Expand All @@ -186,7 +190,7 @@ def _checkMetadata(self, template, config, box, wcs, nInputs):
self.assertEqual(template.getXY0(), expectedBox.getMin())
self.assertEqual(template.filter.bandLabel, "a")
self.assertEqual(template.filter.physicalLabel, "a_test")
self.assertEqual(template.psf.getComponentCount(), nInputs)
self.assertEqual(template.psf.getComponentCount(), nPsfs)

def _checkPixels(self, template, config, box):
"""Check that the pixel values in the template are close to the
Expand Down Expand Up @@ -246,7 +250,7 @@ def testRunOneTractMultipleInputs(self):
physical_filter="a_test")

# All 4 patches from two tracts are included in this template.
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 8)
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 6)
self._checkPixels(result.template, task.config, box)

def testRunTwoTracts(self):
Expand All @@ -262,7 +266,7 @@ def testRunTwoTracts(self):
physical_filter="a_test")

# All 4 patches from all 4 tracts are included in this template
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 16)
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 9)
self._checkPixels(result.template, task.config, box)

def testRunNoTemplate(self):
Expand Down Expand Up @@ -306,7 +310,9 @@ def testMissingPatches(self):
nInput=[8, 16],
)
def testNanInputs(self, box=None, nInput=None):
"""Test that the template has finite values when some of the input pixels have NaN as variance."""
"""Test that the template has finite values when some of the input
pixels have NaN as variance.
"""
for tract, patchRefs in self.patches.items():
for patchRef in patchRefs:
patchCoadd = patchRef.get()
Expand All @@ -322,7 +328,9 @@ def testNanInputs(self, box=None, nInput=None):
wcs=self.exposure.wcs,
dataIds=self.dataIds,
physical_filter="a_test")
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 16)
if debug:
_showTemplate(box, result.template)
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 9)
# We just check that the pixel values are all finite. We cannot check that pixel values
# in the template are closer to the original anymore.
self.assertTrue(np.isfinite(result.template.image.array).all())
Expand Down