From 090207bbf6c72d6f6690dc7a5a9bd759ccd39af6 Mon Sep 17 00:00:00 2001 From: John Parejko Date: Thu, 13 Mar 2025 01:51:48 -0700 Subject: [PATCH 1/4] Restrict getTemplate per-tract handling to a trimmed, warped bbox This should significantly cut memory usage, as it will only make unwarped exposures as large as necessary, not multiple patches in size. --- python/lsst/ip/diffim/getTemplate.py | 53 +++++++++++++++++++--------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/python/lsst/ip/diffim/getTemplate.py b/python/lsst/ip/diffim/getTemplate.py index 1904bc3d7..b82cfb2b8 100644 --- a/python/lsst/ip/diffim/getTemplate.py +++ b/python/lsst/ip/diffim/getTemplate.py @@ -27,6 +27,7 @@ import lsst.geom as geom import lsst.afw.geom as afwGeom import lsst.afw.table as afwTable +from lsst.afw.math._warper import computeWarpedBBox import lsst.afw.math as afwMath import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase @@ -277,25 +278,33 @@ def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter): for tract in coaddExposureHandles: maskedImages, catalog, totalBox = self._makeExposureCatalog(coaddExposureHandles[tract], dataIds[tract]) + warpedBox = computeWarpedBBox(catalog[0].wcs, bbox, wcs) + warpedBox.grow(5) # to ensure we catch all relevant input pixels # Combine images from individual patches together. - unwarped = self._merge(maskedImages, totalBox, catalog[0].wcs) + unwarped, count = self._merge(maskedImages, warpedBox, catalog[0].wcs) # Delete `maskedImages` after combining into one large image to reduce peak memory use del maskedImages - potentialInput = self.warper.warpExposure(wcs, unwarped, destBBox=bbox) + if count == 0: + self.log.info("No valid pixels from coadd patches in tract %s; not including in output.", + tract) + continue + warpedBox.clip(totalBox) + potentialInput = self.warper.warpExposure(wcs, unwarped.subset(warpedBox), destBBox=bbox) # Delete the single large `unwarped` image after warping to reduce peak memory use del unwarped - if not np.any(np.isfinite(potentialInput.image.array)): - self.log.info("No overlap from coadds in tract %s; not including in output.", tract) + if np.all(potentialInput.mask.array & potentialInput.mask.getPlaneBitMask("NO_DATA")): + self.log.info("No overlap from coadd patches in tract %s; not including in output.", tract) continue catalogs.append(catalog) - warped[tract] = potentialInput - warped[tract].setWcs(wcs) + warped[tract] = potentialInput.maskedImage if len(warped) == 0: raise pipeBase.NoWorkFound("No patches found to overlap science exposure.") - template = self._merge([x.maskedImage for x in warped.values()], bbox, wcs) + template, count = self._merge(warped, bbox, wcs) + if count == 0: + raise pipeBase.NoWorkFound("No valid pixels in warped template.") # Make a single catalog containing all the inputs that were accepted. catalog = afwTable.ExposureCatalog(self.schema) @@ -362,7 +371,7 @@ def _makeExposureCatalog(self, exposureRefs, dataIds): Returns ------- - images : `list` [`lsst.afw.image.MaskedImage`] + images : `dict` [`lsst.afw.image.MaskedImage`] MaskedImages of each of the input exposures, for warping. catalog : `lsst.afw.table.ExposureCatalog` Catalog of metadata for each exposure @@ -372,11 +381,11 @@ def _makeExposureCatalog(self, exposureRefs, dataIds): catalog = afwTable.ExposureCatalog(self.schema) catalog.reserve(len(exposureRefs)) exposures = (exposureRef.get() for exposureRef in exposureRefs) - images = [] + images = {} totalBox = geom.Box2I() for coadd, dataId in zip(exposures, dataIds): - images.append(coadd.maskedImage) + images[dataId] = coadd.maskedImage bbox = coadd.getBBox() totalBox = totalBox.expandedTo(bbox) record = catalog.addNew() @@ -393,15 +402,15 @@ def _makeExposureCatalog(self, exposureRefs, dataIds): return images, catalog, totalBox - @staticmethod - def _merge(maskedImages, bbox, wcs): + def _merge(self, maskedImages, bbox, wcs): """Merge the images that came from one tract into one larger image, ignoring NaN pixels and non-finite variance pixels from individual exposures. Parameters ---------- - maskedImages : `list` [`lsst.afw.image.MaskedImage`] + maskedImages : `dict` [`lsst.afw.image.MaskedImage` or + `lsst.afw.image.Exposure`] Images to be merged into one larger bounding box. bbox : `lsst.geom.Box2I` Bounding box defining the image to merge into. @@ -413,10 +422,20 @@ def _merge(maskedImages, bbox, wcs): merged : `lsst.afw.image.MaskedImage` Merged image with all of the inputs at their respective bbox positions. + count : `int` + Count of the number of good pixels (those with positive weights) + in the merged image. """ merged = afwImage.ExposureF(bbox, wcs) weights = afwImage.ImageF(bbox) - for maskedImage in maskedImages: + for dataId, maskedImage in maskedImages.items(): + # Only merge into the trimmed box, to save memory + clippedBox = geom.Box2I(maskedImage.getBBox()) + clippedBox.clip(bbox) + if clippedBox.area == 0: + self.log.debug("%s does not overlap template region.", dataId) + continue # nothing in this image overlaps the output + maskedImage = maskedImage.subset(clippedBox) # Catch both zero-value and NaN variance plane pixels good = (maskedImage.variance.array > 0) & (np.isfinite(maskedImage.variance.array)) weight = maskedImage.variance.array[good]**(-0.5) @@ -432,10 +451,10 @@ def _merge(maskedImages, bbox, wcs): # `weight` are the exact values we want to scale by. maskedImage.image.array[good] *= weight maskedImage.variance.array[good] *= weight - weights[maskedImage.getBBox()].array[good] += weight + weights[clippedBox].array[good] += weight # Free memory before creating new large arrays del weight - merged.maskedImage[maskedImage.getBBox()] += maskedImage + merged.maskedImage[clippedBox] += maskedImage good = weights.array > 0 @@ -448,7 +467,7 @@ def _merge(maskedImages, bbox, wcs): merged.mask.array[~good] |= merged.mask.getPlaneBitMask("NO_DATA") - return merged + return merged, good.sum() def _makePsf(self, template, catalog, wcs): """Return a PSF containing the PSF at each of the input regions. From 5a961067ff7625d1e6e738ba3611181e707a1379 Mon Sep 17 00:00:00 2001 From: John Parejko Date: Thu, 13 Mar 2025 01:52:56 -0700 Subject: [PATCH 2/4] Use a real DataCoordinate in mocked data This is necessary to use the dataId as a dict key. --- tests/test_getTemplate.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_getTemplate.py b/tests/test_getTemplate.py index 20d174ced..8afe4120d 100644 --- a/tests/test_getTemplate.py +++ b/tests/test_getTemplate.py @@ -28,6 +28,7 @@ import lsst.afw.geom import lsst.afw.image import lsst.afw.math +from lsst.daf.butler import DataCoordinate, DimensionUniverse import lsst.geom import lsst.ip.diffim import lsst.meas.algorithms @@ -168,9 +169,12 @@ def _makePatches(self, tract): ) ) self.patches[tract.tract_id].append(dataRef) - self.dataIds[tract.tract_id].append({"tract": tract.tract_id, - "patch": patchId, - "band": "a"}) + dataCoordinate = DataCoordinate.standardize({"tract": tract.tract_id, + "patch": patchId, + "band": "a", + "skymap": "skymap"}, + universe=DimensionUniverse()) + self.dataIds[tract.tract_id].append(dataCoordinate) def _checkMetadata(self, template, config, box, wcs, nInputs): """Check that the various metadata components were set correctly. From 4e5e9693bba5d1b925b3e3664100a86766d56f5b Mon Sep 17 00:00:00 2001 From: John Parejko Date: Fri, 14 Mar 2025 10:44:38 -0700 Subject: [PATCH 3/4] Improve psf handling We can trim the exposure catalog to only include patches that were used, thus reducing the number of entries in the final CoaddPsf. --- python/lsst/ip/diffim/getTemplate.py | 21 ++++++++++++++++----- tests/test_getTemplate.py | 12 +++++++----- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/lsst/ip/diffim/getTemplate.py b/python/lsst/ip/diffim/getTemplate.py index b82cfb2b8..1eab2d509 100644 --- a/python/lsst/ip/diffim/getTemplate.py +++ b/python/lsst/ip/diffim/getTemplate.py @@ -281,7 +281,7 @@ def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter): warpedBox = computeWarpedBBox(catalog[0].wcs, bbox, wcs) warpedBox.grow(5) # to ensure we catch all relevant input pixels # Combine images from individual patches together. - unwarped, count = self._merge(maskedImages, warpedBox, catalog[0].wcs) + unwarped, count, included = self._merge(maskedImages, warpedBox, catalog[0].wcs) # Delete `maskedImages` after combining into one large image to reduce peak memory use del maskedImages if count == 0: @@ -297,12 +297,18 @@ def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter): self.log.info("No overlap from coadd patches in tract %s; not including in output.", tract) continue - catalogs.append(catalog) + # Trim the exposure catalog to just the patches that were used. + tempCatalog = afwTable.ExposureCatalog(self.schema) + tempCatalog.reserve(len(included)) + for i in included: + tempCatalog.append(catalog[i]) + catalogs.append(tempCatalog) warped[tract] = potentialInput.maskedImage if len(warped) == 0: raise pipeBase.NoWorkFound("No patches found to overlap science exposure.") - template, count = self._merge(warped, bbox, wcs) + # At this point, all entries will be valid, so we can ignore included. + template, count, _ = self._merge(warped, bbox, wcs) if count == 0: raise pipeBase.NoWorkFound("No valid pixels in warped template.") @@ -425,10 +431,14 @@ def _merge(self, maskedImages, bbox, wcs): count : `int` Count of the number of good pixels (those with positive weights) in the merged image. + included : `list` [`int`] + List of indexes of patches that were included in the merged + result, to be used to trim the exposure catalog. """ merged = afwImage.ExposureF(bbox, wcs) weights = afwImage.ImageF(bbox) - for dataId, maskedImage in maskedImages.items(): + included = [] # which patches were included in the result + for i, (dataId, maskedImage) in enumerate(maskedImages.items()): # Only merge into the trimmed box, to save memory clippedBox = geom.Box2I(maskedImage.getBBox()) clippedBox.clip(bbox) @@ -455,6 +465,7 @@ def _merge(self, maskedImages, bbox, wcs): # Free memory before creating new large arrays del weight merged.maskedImage[clippedBox] += maskedImage + included.append(i) good = weights.array > 0 @@ -467,7 +478,7 @@ def _merge(self, maskedImages, bbox, wcs): merged.mask.array[~good] |= merged.mask.getPlaneBitMask("NO_DATA") - return merged, good.sum() + return merged, good.sum(), included def _makePsf(self, template, catalog, wcs): """Return a PSF containing the PSF at each of the input regions. diff --git a/tests/test_getTemplate.py b/tests/test_getTemplate.py index 8afe4120d..4adb91850 100644 --- a/tests/test_getTemplate.py +++ b/tests/test_getTemplate.py @@ -176,7 +176,7 @@ def _makePatches(self, tract): universe=DimensionUniverse()) self.dataIds[tract.tract_id].append(dataCoordinate) - def _checkMetadata(self, template, config, box, wcs, nInputs): + def _checkMetadata(self, template, config, box, wcs, nPsfs): """Check that the various metadata components were set correctly. """ expectedBox = lsst.geom.Box2I(box) @@ -190,7 +190,7 @@ def _checkMetadata(self, template, config, box, wcs, nInputs): self.assertEqual(template.getXY0(), expectedBox.getMin()) self.assertEqual(template.filter.bandLabel, "a") self.assertEqual(template.filter.physicalLabel, "a_test") - self.assertEqual(template.psf.getComponentCount(), nInputs) + self.assertEqual(template.psf.getComponentCount(), nPsfs) def _checkPixels(self, template, config, box): """Check that the pixel values in the template are close to the @@ -250,7 +250,7 @@ def testRunOneTractMultipleInputs(self): physical_filter="a_test") # All 4 patches from two tracts are included in this template. - self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 8) + self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 6) self._checkPixels(result.template, task.config, box) def testRunTwoTracts(self): @@ -266,7 +266,7 @@ def testRunTwoTracts(self): physical_filter="a_test") # All 4 patches from all 4 tracts are included in this template - self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 16) + self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 9) self._checkPixels(result.template, task.config, box) def testRunNoTemplate(self): @@ -326,7 +326,9 @@ def testNanInputs(self, box=None, nInput=None): wcs=self.exposure.wcs, dataIds=self.dataIds, physical_filter="a_test") - self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 16) + if debug: + _showTemplate(box, result.template) + self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 9) # We just check that the pixel values are all finite. We cannot check that pixel values # in the template are closer to the original anymore. self.assertTrue(np.isfinite(result.template.image.array).all()) From ff976f62b8299bc760c3da8c79f843348e9a4362 Mon Sep 17 00:00:00 2001 From: John Parejko Date: Fri, 14 Mar 2025 10:44:43 -0700 Subject: [PATCH 4/4] Reflow docstring --- tests/test_getTemplate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_getTemplate.py b/tests/test_getTemplate.py index 4adb91850..4fcba1ac4 100644 --- a/tests/test_getTemplate.py +++ b/tests/test_getTemplate.py @@ -310,7 +310,9 @@ def testMissingPatches(self): nInput=[8, 16], ) def testNanInputs(self, box=None, nInput=None): - """Test that the template has finite values when some of the input pixels have NaN as variance.""" + """Test that the template has finite values when some of the input + pixels have NaN as variance. + """ for tract, patchRefs in self.patches.items(): for patchRef in patchRefs: patchCoadd = patchRef.get()