Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions python/lsst/ip/diffim/getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):

results = self.getExposures(coaddExposures, bbox, skymap, wcs)
physical_filter = butlerQC.quantum.dataId["physical_filter"]
outputs = self.run(coaddExposures=results.coaddExposures,
outputs = self.run(coaddExposureHandles=results.coaddExposures,
bbox=bbox,
wcs=wcs,
dataIds=results.dataIds,
Expand Down Expand Up @@ -184,7 +184,8 @@ def getExposures(self, coaddExposureHandles, bbox, skymap, wcs):
``coaddExposures``
Dict of coadd exposures that overlap the projected bbox,
indexed on tract id
(`dict` [`int`, `list` [`lsst.afw.image.Exposure`] ]).
(`dict` [`int`, `list` [`lsst.daf.butler.DeferredDatasetHandle` of
`lsst.afw.image.Exposure`] ]).
``dataIds``
Dict of data IDs of the coadd exposures that overlap the
projected bbox, indexed on tract id
Expand Down Expand Up @@ -214,7 +215,7 @@ def getExposures(self, coaddExposureHandles, bbox, skymap, wcs):
if patchPolygon.intersection(detectorPolygon):
overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea()
self.log.info("Using template input tract=%s, patch=%s", dataId['tract'], dataId['patch'])
coaddExposures[dataId['tract']].append(coaddRef.get())
coaddExposures[dataId['tract']].append(coaddRef)
dataIds[dataId['tract']].append(dataId)

if not overlappingArea:
Expand All @@ -224,7 +225,7 @@ def getExposures(self, coaddExposureHandles, bbox, skymap, wcs):
dataIds=dataIds)

@timeMethod
def run(self, *, coaddExposures, bbox, wcs, dataIds, physical_filter):
def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter):
"""Warp coadds from multiple tracts and patches to form a template to
subtract from a science image.

Expand All @@ -237,14 +238,16 @@ def run(self, *, coaddExposures, bbox, wcs, dataIds, physical_filter):

Parameters
----------
coaddExposures : `dict` [`int`, `list` [`lsst.afw.image.Exposure`]]
coaddExposureHandles : `dict` [`int`, `list` of \
[`lsst.daf.butler.DeferredDatasetHandle` of \
`lsst.afw.image.Exposure`]]
Coadds to be mosaicked, indexed on tract id.
bbox : `lsst.geom.Box2I`
Template Bounding box of the detector geometry onto which to
resample the ``coaddExposures``. Modified in-place to include the
resample the ``coaddExposureHandles``. Modified in-place to include the
template border.
wcs : `lsst.afw.geom.SkyWcs`
Template WCS onto which to resample the ``coaddExposures``.
Template WCS onto which to resample the ``coaddExposureHandles``.
dataIds : `dict` [`int`, `list` [`lsst.daf.butler.DataCoordinate`]]
Record of the tract and patch of each coaddExposure, indexed on
tract id.
Expand All @@ -265,21 +268,27 @@ def run(self, *, coaddExposures, bbox, wcs, dataIds, physical_filter):
NoWorkFound
If no coadds are found with sufficient un-masked pixels.
"""
band, photoCalib = self._checkInputs(dataIds, coaddExposures)
band, photoCalib = self._checkInputs(dataIds, coaddExposureHandles)

bbox.grow(self.config.templateBorderSize)

warped = {}
catalogs = []
for tract in coaddExposures:
maskedImages, catalog, totalBox = self._makeExposureCatalog(coaddExposures[tract],
for tract in coaddExposureHandles:
maskedImages, catalog, totalBox = self._makeExposureCatalog(coaddExposureHandles[tract],
dataIds[tract])
# Combine images from individual patches together.
unwarped = self._merge(maskedImages, totalBox, catalog[0].wcs)
# Delete `maskedImages` after combining into one large image to reduce peak memory use
del maskedImages
potentialInput = self.warper.warpExposure(wcs, unwarped, destBBox=bbox)

# Delete the single large `unwarped` image after warping to reduce peak memory use
del unwarped
if not np.any(np.isfinite(potentialInput.image.array)):
self.log.info("No overlap from coadds in tract %s; not including in output.", tract)
continue

catalogs.append(catalog)
warped[tract] = potentialInput
warped[tract].setWcs(wcs)
Expand Down Expand Up @@ -308,7 +317,9 @@ def _checkInputs(dataIds, coaddExposures):
----------
dataIds : `dict` [`int`, `list` [`lsst.daf.butler.DataCoordinate`]]
Record of the tract and patch of each coaddExposure.
coaddExposures : `dict` [`int`, `list` [`lsst.afw.image.Exposure`]]
coaddExposures : `dict` [`int`, `list` of \
[`lsst.daf.butler.DeferredDatasetHandle` of \
`lsst.afw.image.Exposure`]]
Coadds to be mosaicked.

Returns
Expand All @@ -328,19 +339,22 @@ def _checkInputs(dataIds, coaddExposures):
if len(bands) > 1:
raise RuntimeError(f"GetTemplateTask called with multiple bands: {bands}")
band = bands.pop()
photoCalibs = [exposure.photoCalib for exposures in coaddExposures.values() for exposure in exposures]
photoCalibs = [exposure.get(component="photoCalib")
for exposures in coaddExposures.values()
for exposure in exposures]
if not all([photoCalibs[0] == x for x in photoCalibs]):
msg = f"GetTemplateTask called with exposures with different photoCalibs: {photoCalibs}"
raise RuntimeError(msg)
photoCalib = photoCalibs[0]
return band, photoCalib

def _makeExposureCatalog(self, exposures, dataIds):
def _makeExposureCatalog(self, exposureRefs, dataIds):
"""Make an exposure catalog for one tract.

Parameters
----------
exposures : `list` [`lsst.afw.image.Exposuref`]
exposureRefs : `list` of [`lsst.daf.butler.DeferredDatasetHandle` of \
`lsst.afw.image.Exposure`]
Exposures to include in the catalog.
dataIds : `list` [`lsst.daf.butler.DataCoordinate`]
Data ids of each of the included exposures; must have "tract" and
Expand All @@ -356,17 +370,21 @@ def _makeExposureCatalog(self, exposures, dataIds):
The union of the bounding boxes of all the input exposures.
"""
catalog = afwTable.ExposureCatalog(self.schema)
catalog.reserve(len(exposures))
images = [exposure.maskedImage for exposure in exposures]
catalog.reserve(len(exposureRefs))
exposures = (exposureRef.get() for exposureRef in exposureRefs)
images = []
totalBox = geom.Box2I()

for coadd, dataId in zip(exposures, dataIds):
totalBox = totalBox.expandedTo(coadd.getBBox())
images.append(coadd.maskedImage)
bbox = coadd.getBBox()
totalBox = totalBox.expandedTo(bbox)
record = catalog.addNew()
record.setPsf(coadd.psf)
record.setWcs(coadd.wcs)
record.setPhotoCalib(coadd.photoCalib)
record.setBBox(coadd.getBBox())
record.setValidPolygon(afwGeom.Polygon(geom.Box2D(coadd.getBBox()).getCorners()))
record.setBBox(bbox)
record.setValidPolygon(afwGeom.Polygon(geom.Box2D(bbox).getCorners()))
record.set("tract", dataId["tract"])
record.set("patch", dataId["patch"])
# Weight is used by CoaddPsf, but the PSFs from overlapping patches
Expand Down Expand Up @@ -400,9 +418,8 @@ def _merge(maskedImages, bbox, wcs):
weights = afwImage.ImageF(bbox)
for maskedImage in maskedImages:
# Catch both zero-value and NaN variance plane pixels
good = maskedImage.variance.array > 0
weight = afwImage.ImageF(maskedImage.getBBox())
weight.array[good] = maskedImage.variance.array[good]**(-0.5)
good = (maskedImage.variance.array > 0) & (np.isfinite(maskedImage.variance.array))
weight = maskedImage.variance.array[good]**(-0.5)
bad = np.isnan(maskedImage.image.array) | ~good
# Note that modifying the patch MaskedImage in place is fine;
# we're throwing it away at the end anyway.
Expand All @@ -413,21 +430,23 @@ def _merge(maskedImages, bbox, wcs):
# Cannot use `merged.maskedImage *= weight` because that operator
# multiplies the variance by the weight twice; in this case
# `weight` are the exact values we want to scale by.
maskedImage.image *= weight
maskedImage.variance *= weight
maskedImage.image.array[good] *= weight
maskedImage.variance.array[good] *= weight
weights[maskedImage.getBBox()].array[good] += weight
# Free memory before creating new large arrays
del weight
merged.maskedImage[maskedImage.getBBox()] += maskedImage
weights[maskedImage.getBBox()] += weight

inverseWeights = np.zeros_like(weights.array)
good = weights.array > 0
inverseWeights[good] = 1/weights.array[good]

# Cannot use `merged.maskedImage *= inverseWeights` because that
# Cannot use `merged.maskedImage /= weights` because that
# operator divides the variance by the weight twice; in this case
# `inverseWeights` are the exact values we want to scale by.
merged.image.array *= inverseWeights
merged.variance.array *= inverseWeights
merged.mask.array |= merged.mask.getPlaneBitMask("NO_DATA") * (inverseWeights == 0)
# `weights` are the exact values we want to scale by.
weights = weights.array[good]
merged.image.array[good] /= weights
merged.variance.array[good] /= weights

merged.mask.array[~good] |= merged.mask.getPlaneBitMask("NO_DATA")

return merged

Expand Down
33 changes: 23 additions & 10 deletions tests/test_getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
import lsst.ip.diffim
import lsst.meas.algorithms
import lsst.meas.base.tests
import lsst.pipe.base as pipeBase
import lsst.skymap
import lsst.utils.tests

from utils import generate_data_id

# Change this to True, `setup display_ds9`, and open ds9 (or use another afw
# display backend) to show the tract/patch layouts on the image.
debug = False
Expand Down Expand Up @@ -155,7 +158,16 @@ def _makePatches(self, tract):
warpedPsf = lsst.meas.algorithms.WarpedPsf(self.exposure.psf, xyTransform)
warped = warper.warpExposure(patch.wcs, self.exposure, destBBox=box)
warped.setPsf(warpedPsf)
self.patches[tract.tract_id].append(warped)
dataRef = pipeBase.InMemoryDatasetHandle(
warped,
storageClass="ExposureF",
copy=True,
dataId=generate_data_id(
tract=tract,
patch=patch,
)
)
self.patches[tract.tract_id].append(dataRef)
self.dataIds[tract.tract_id].append({"tract": tract.tract_id,
"patch": patchId,
"band": "a"})
Expand All @@ -168,7 +180,7 @@ def _checkMetadata(self, template, config, box, wcs, nInputs):
self.assertEqual(template.getBBox(), expectedBox)
# WCS should match our exposure, not any of the coadd tracts.
for tract in self.patches:
self.assertNotEqual(template.wcs, self.patches[tract][0].wcs)
self.assertNotEqual(template.wcs, self.patches[tract][0].get().wcs)
self.assertEqual(template.wcs, self.exposure.wcs)
self.assertEqual(template.photoCalib, self.exposure.photoCalib)
self.assertEqual(template.getXY0(), expectedBox.getMin())
Expand Down Expand Up @@ -209,7 +221,7 @@ def testRunOneTractInput(self):
task = lsst.ip.diffim.GetTemplateTask()
# Restrict to tract 0, since the box fits in just that tract.
# Task modifies the input bbox, so pass a copy.
result = task.run(coaddExposures={0: self.patches[0]},
result = task.run(coaddExposureHandles={0: self.patches[0]},
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds={0: self.dataIds[0]},
Expand All @@ -227,7 +239,7 @@ def testRunOneTractMultipleInputs(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
task = lsst.ip.diffim.GetTemplateTask()
# Task modifies the input bbox, so pass a copy.
result = task.run(coaddExposures=self.patches,
result = task.run(coaddExposureHandles=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
Expand All @@ -243,7 +255,7 @@ def testRunTwoTracts(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(200, 200), lsst.geom.Point2I(600, 600))
task = lsst.ip.diffim.GetTemplateTask()
# Task modifies the input bbox, so pass a copy.
result = task.run(coaddExposures=self.patches,
result = task.run(coaddExposureHandles=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
Expand All @@ -259,7 +271,7 @@ def testRunNoTemplate(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(1200, 1200), lsst.geom.Point2I(1600, 1600))
task = lsst.ip.diffim.GetTemplateTask()
with self.assertRaisesRegex(lsst.pipe.base.NoWorkFound, "No patches found"):
task.run(coaddExposures=self.patches,
task.run(coaddExposureHandles=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
Expand All @@ -276,7 +288,7 @@ def testMissingPatches(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
task = lsst.ip.diffim.GetTemplateTask()
# Task modifies the input bbox, so pass a copy.
result = task.run(coaddExposures=self.patches,
result = task.run(coaddExposureHandles=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
Expand All @@ -295,16 +307,17 @@ def testMissingPatches(self):
)
def testNanInputs(self, box=None, nInput=None):
"""Test that the template has finite values when some of the input pixels have NaN as variance."""
for tract, patchCoadds in self.patches.items():
for patchCoadd in patchCoadds:
for tract, patchRefs in self.patches.items():
for patchRef in patchRefs:
patchCoadd = patchRef.get()
bbox = lsst.geom.Box2I()
bbox.include(lsst.geom.Point2I(patchCoadd.getBBox().getCenter()))
bbox.grow(3)
patchCoadd.variance[bbox].array *= np.nan

box = lsst.geom.Box2I(lsst.geom.Point2I(200, 200), lsst.geom.Point2I(600, 600))
task = lsst.ip.diffim.GetTemplateTask()
result = task.run(coaddExposures=self.patches,
result = task.run(coaddExposureHandles=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
Expand Down
74 changes: 74 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import lsst.afw.image as afwImage
import lsst.afw.math as afwMath
import lsst.afw.table as afwTable
from lsst.daf.butler import DataCoordinate, DimensionUniverse
import lsst.meas.algorithms as measAlg
import lsst.meas.base as measBase
from lsst.meas.algorithms.testUtils import plantSources
Expand Down Expand Up @@ -1147,3 +1148,76 @@ class CustomCoaddPsf(measAlg.CoaddPsf):
"""
def getAveragePosition(self):
return geom.Point2D(-10000, -10000)


def generate_data_id(*,
                     tract: int = 9813,
                     patch: int = 42,
                     cell_x: int = 4,
                     cell_y: int = 2,
                     band: str = "notR",
                     ) -> DataCoordinate:
    """Generate a DataCoordinate instance to use as data_id.

    Modified from ``generate_data_id`` in ``lsst.cell_coadds.test_utils``.

    Parameters
    ----------
    tract : `int`, optional
        Tract ID for the data_id
    patch : `int`, optional
        Patch ID for the data_id
    cell_x : `int`, optional
        X index of the cell this patch corresponds to
    cell_y : `int`, optional
        Y index of the cell this patch corresponds to
    band : `str`, optional
        Band for the data_id

    Returns
    -------
    data_id : `lsst.daf.butler.DataCoordinate`
        An expanded data_id instance.

    Notes
    -----
    This helper is a local, modified copy so that this package does not
    need a dependency on ``lsst.cell_coadds``. It is a candidate for being
    refactored into a shared package (e.g. ``utils`` or ``pipe_base``) on a
    separate ticket.
    """
    universe = DimensionUniverse()

    instrument = universe["instrument"]
    instrument_record = instrument.RecordClass(
        name="DummyCam",
        class_name="lsst.obs.base.instrument_tests.DummyCam",
    )

    skymap = universe["skymap"]
    skymap_record = skymap.RecordClass(name="test_skymap")

    band_element = universe["band"]
    band_record = band_element.RecordClass(name=band)

    physical_filter = universe["physical_filter"]
    # NOTE(review): the instrument named here ("test") differs from the
    # instrument record above ("DummyCam"); left as-is pending confirmation
    # that the mismatch is intentional.
    physical_filter_record = physical_filter.RecordClass(name=band, instrument="test", band=band)

    patch_element = universe["patch"]
    patch_record = patch_element.RecordClass(
        skymap=skymap_record.name, tract=tract, patch=patch, cell_x=cell_x, cell_y=cell_y
    )

    # A dictionary with all the relevant records.
    record = {
        "instrument": instrument_record,
        "patch": patch_record,
        # Use the caller-supplied tract rather than a fixed value, so the
        # data ID agrees with the tract stored in ``patch_record``.
        "tract": tract,
        "band": band_record.name,
        "skymap": skymap_record.name,
        "physical_filter": physical_filter_record,
    }

    # A dictionary with all the relevant recordIds: the instrument and
    # physical_filter records are replaced by their names; every other
    # entry is passed through unchanged.
    record_id = record.copy()
    for key in (
        "instrument",
        "physical_filter",
    ):
        record_id[key] = record_id[key].name

    data_id = DataCoordinate.standardize(record_id, universe=universe)
    return data_id.expanded(record)