diff --git a/python/lsst/ip/diffim/dcrModel.py b/python/lsst/ip/diffim/dcrModel.py index c6def2062..6a953af27 100644 --- a/python/lsst/ip/diffim/dcrModel.py +++ b/python/lsst/ip/diffim/dcrModel.py @@ -25,6 +25,7 @@ from lsst.afw.coord import differentialRefraction import lsst.afw.image as afwImage import lsst.geom as geom +import lsst.pipe.base as pipeBase __all__ = ["DcrModel", "applyDcr", "calculateDcr", "calculateImageParallacticAngle"] @@ -471,6 +472,48 @@ def buildMatchedExposure(self, exposure=None, templateExposure.setPhotoCalib(self.photoCalib) return templateExposure + def buildMatchedExposureHandle(self, exposure=None, visitInfo=None, bbox=None, mask=None): + """Create in-memory butler dataset reference containing the DCR-matched + template. + + Parameters + ---------- + exposure : `lsst.afw.image.Exposure`, optional + The input exposure to build a matched template for. + May be omitted if all of the metadata is supplied separately + visitInfo : `lsst.afw.image.VisitInfo`, optional + Metadata for the exposure. Ignored if ``exposure`` is set. + bbox : `lsst.afw.geom.Box2I`, optional + Sub-region of the coadd, or use the entire coadd if not supplied. + mask : `lsst.afw.image.Mask`, optional + reference mask to use for the template image. + + Returns + ------- + templateExposureHandle: `lsst.pipe.base.InMemoryDatasetHandle` + In-memory butler dataset reference containing the DCR-matched + template. + + Raises + ------ + ValueError + If no `exposure` or `visitInfo` is set. + """ + if exposure is not None: + visitInfo = exposure.visitInfo + elif visitInfo is None: + raise ValueError("Either exposure or visitInfo must be set.") + templateExposure = self.buildMatchedExposure( + exposure=exposure, visitInfo=visitInfo, bbox=bbox, mask=mask + ) + templateExposureHandle = pipeBase.InMemoryDatasetHandle( + templateExposure, + storageClass="ExposureF", + copy=False, + photoCalib=self.photoCalib + ) + return templateExposureHandle + def conditionDcrModel(self, modelImages, bbox, gain=1.): """Average two iterations' solutions to reduce oscillations. diff --git a/python/lsst/ip/diffim/getTemplate.py b/python/lsst/ip/diffim/getTemplate.py index 1eab2d509..639f67026 100644 --- a/python/lsst/ip/diffim/getTemplate.py +++ b/python/lsst/ip/diffim/getTemplate.py @@ -31,20 +31,25 @@ import lsst.afw.math as afwMath import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase + from lsst.skymap import BaseSkyMap from lsst.ip.diffim.dcrModel import DcrModel from lsst.meas.algorithms import CoaddPsf, CoaddPsfConfig from lsst.utils.timer import timeMethod -__all__ = ["GetTemplateTask", "GetTemplateConfig", - "GetDcrTemplateTask", "GetDcrTemplateConfig"] +__all__ = [ + "GetTemplateTask", + "GetTemplateConfig", + "GetDcrTemplateTask", + "GetDcrTemplateConfig", +] -class GetTemplateConnections(pipeBase.PipelineTaskConnections, - dimensions=("instrument", "visit", "detector", "skymap"), - defaultTemplates={"coaddName": "goodSeeing", - "warpTypeSuffix": "", - "fakesType": ""}): +class GetTemplateConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit", "detector", "skymap"), + defaultTemplates={"coaddName": "goodSeeing", "warpTypeSuffix": "", "fakesType": ""}, +): bbox = pipeBase.connectionTypes.Input( doc="Bounding box of exposure to determine the geometry of the output template.", name="{fakesType}calexp.bbox", @@ -60,18 +65,18 @@ class GetTemplateConnections(pipeBase.PipelineTaskConnections, skyMap = pipeBase.connectionTypes.Input( doc="Geometry of the tracts and patches that the coadds are defined on.", name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, - dimensions=("skymap", ), + dimensions=("skymap",), storageClass="SkyMap", ) coaddExposures = pipeBase.connectionTypes.Input( doc="Coadds that may overlap the desired region, as possible inputs to the template." - " Will be restricted to those that directly overlap the projected bounding box.", + " Will be restricted to those that directly overlap the projected bounding box.", dimensions=("tract", "patch", "skymap", "band"), storageClass="ExposureF", name="{fakesType}{coaddName}Coadd{warpTypeSuffix}", multiple=True, deferLoad=True, - deferGraphConstraint=True + deferGraphConstraint=True, ) template = pipeBase.connectionTypes.Output( @@ -82,12 +87,13 @@ class GetTemplateConnections(pipeBase.PipelineTaskConnections, ) -class GetTemplateConfig(pipeBase.PipelineTaskConfig, - pipelineConnections=GetTemplateConnections): +class GetTemplateConfig( + pipeBase.PipelineTaskConfig, pipelineConnections=GetTemplateConnections +): templateBorderSize = pexConfig.Field( dtype=int, default=20, - doc="Number of pixels to grow the requested template image to account for warping" + doc="Number of pixels to grow the requested template image to account for warping", ) warp = pexConfig.ConfigField( dtype=afwMath.Warper.ConfigClass, @@ -106,7 +112,7 @@ def setDefaults(self): # The WCS for LSST should be smoothly varying, so we can use a longer # interpolation length for WCS evaluations. self.warp.interpLength = 100 - self.warp.warpingKernelName = 'lanczos3' + self.warp.warpingKernelName = "lanczos3" self.coaddPsf.warpingKernelName = self.warp.warpingKernelName @@ -118,16 +124,25 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.warper = afwMath.Warper.fromConfig(self.config.warp) self.schema = afwTable.ExposureTable.makeMinimalSchema() - self.schema.addField('tract', type=np.int32, doc='Which tract this exposure came from.') - self.schema.addField('patch', type=np.int32, doc='Which patch in the tract this exposure came from.') - self.schema.addField('weight', type=float, - doc='Weight for each exposure, used to make the CoaddPsf; should always be 1.') + self.schema.addField( + "tract", type=np.int32, doc="Which tract this exposure came from." + ) + self.schema.addField( + "patch", + type=np.int32, + doc="Which patch in the tract this exposure came from.", + ) + self.schema.addField( + "weight", + type=float, + doc="Weight for each exposure, used to make the CoaddPsf; should always be 1.", + ) def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) bbox = inputs.pop("bbox") wcs = inputs.pop("wcs") - coaddExposures = inputs.pop('coaddExposures') + coaddExposures = inputs.pop("coaddExposures") skymap = inputs.pop("skyMap") # This should not happen with a properly configured execution context. @@ -135,21 +150,25 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): results = self.getExposures(coaddExposures, bbox, skymap, wcs) physical_filter = butlerQC.quantum.dataId["physical_filter"] - outputs = self.run(coaddExposureHandles=results.coaddExposures, - bbox=bbox, - wcs=wcs, - dataIds=results.dataIds, - physical_filter=physical_filter) + outputs = self.run( + coaddExposureHandles=results.coaddExposures, + bbox=bbox, + wcs=wcs, + dataIds=results.dataIds, + physical_filter=physical_filter, + ) butlerQC.put(outputs, outputRefs) - @deprecated(reason="Replaced by getExposures, which uses explicit arguments instead of a kwargs dict. " - "This method will be removed after v29.", - version="v29.0", category=FutureWarning) + @deprecated( + reason="Replaced by getExposures, which uses explicit arguments instead of a kwargs dict. " + "This method will be removed after v29.", + version="v29.0", + category=FutureWarning, + ) def getOverlappingExposures(self, inputs): - return self.getExposures(inputs["coaddExposures"], - inputs["bbox"], - inputs["skyMap"], - inputs["wcs"]) + return self.getExposures( + inputs["coaddExposures"], inputs["bbox"], inputs["skyMap"], inputs["wcs"] + ) def getExposures(self, coaddExposureHandles, bbox, skymap, wcs): """Return a data structure containing the coadds that overlap the @@ -199,7 +218,9 @@ def getExposures(self, coaddExposureHandles, bbox, skymap, wcs): WCS is None. """ if wcs is None: - raise pipeBase.NoWorkFound("WCS is None; cannot find overlapping exposures.") + raise pipeBase.NoWorkFound( + "WCS is None; cannot find overlapping exposures." + ) # Exposure's validPolygon would be more accurate detectorPolygon = geom.Box2D(bbox) @@ -209,21 +230,26 @@ def getExposures(self, coaddExposureHandles, bbox, skymap, wcs): for coaddRef in coaddExposureHandles: dataId = coaddRef.dataId - patchWcs = skymap[dataId['tract']].getWcs() - patchBBox = skymap[dataId['tract']][dataId['patch']].getOuterBBox() + patchWcs = skymap[dataId["tract"]].getWcs() + patchBBox = skymap[dataId["tract"]][dataId["patch"]].getOuterBBox() patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners()) patchPolygon = afwGeom.Polygon(wcs.skyToPixel(patchCorners)) if patchPolygon.intersection(detectorPolygon): - overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea() - self.log.info("Using template input tract=%s, patch=%s", dataId['tract'], dataId['patch']) - coaddExposures[dataId['tract']].append(coaddRef) - dataIds[dataId['tract']].append(dataId) + overlappingArea += patchPolygon.intersectionSingle( + detectorPolygon + ).calculateArea() + self.log.info( + "Using template input tract=%s, patch=%s", + dataId["tract"], + dataId["patch"], + ) + coaddExposures[dataId["tract"]].append(coaddRef) + dataIds[dataId["tract"]].append(dataId) if not overlappingArea: - raise pipeBase.NoWorkFound('No patches overlap detector') + raise pipeBase.NoWorkFound("No patches overlap detector") - return pipeBase.Struct(coaddExposures=coaddExposures, - dataIds=dataIds) + return pipeBase.Struct(coaddExposures=coaddExposures, dataIds=dataIds) @timeMethod def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter): @@ -276,25 +302,38 @@ def run(self, *, coaddExposureHandles, bbox, wcs, dataIds, physical_filter): warped = {} catalogs = [] for tract in coaddExposureHandles: - maskedImages, catalog, totalBox = self._makeExposureCatalog(coaddExposureHandles[tract], - dataIds[tract]) + maskedImages, catalog, totalBox = self._makeExposureCatalog( + coaddExposureHandles[tract], dataIds[tract] + ) warpedBox = computeWarpedBBox(catalog[0].wcs, bbox, wcs) warpedBox.grow(5) # to ensure we catch all relevant input pixels # Combine images from individual patches together. - unwarped, count, included = self._merge(maskedImages, warpedBox, catalog[0].wcs) + unwarped, count, included = self._merge( + maskedImages, warpedBox, catalog[0].wcs + ) # Delete `maskedImages` after combining into one large image to reduce peak memory use del maskedImages if count == 0: - self.log.info("No valid pixels from coadd patches in tract %s; not including in output.", - tract) + self.log.info( + "No valid pixels from coadd patches in tract %s; not including in output.", + tract, + ) continue warpedBox.clip(totalBox) - potentialInput = self.warper.warpExposure(wcs, unwarped.subset(warpedBox), destBBox=bbox) + potentialInput = self.warper.warpExposure( + wcs, unwarped.subset(warpedBox), destBBox=bbox + ) # Delete the single large `unwarped` image after warping to reduce peak memory use del unwarped - if np.all(potentialInput.mask.array & potentialInput.mask.getPlaneBitMask("NO_DATA")): - self.log.info("No overlap from coadd patches in tract %s; not including in output.", tract) + if np.all( + potentialInput.mask.array + & potentialInput.mask.getPlaneBitMask("NO_DATA") + ): + self.log.info( + "No overlap from coadd patches in tract %s; not including in output.", + tract, + ) continue # Trim the exposure catalog to just the patches that were used. @@ -334,6 +373,7 @@ def _checkInputs(dataIds, coaddExposures): Record of the tract and patch of each coaddExposure. coaddExposures : `dict` [`int`, `list` of \ [`lsst.daf.butler.DeferredDatasetHandle` of \ + `lsst.afw.image.Exposure` or `lsst.afw.image.Exposure`]] Coadds to be mosaicked. @@ -354,9 +394,11 @@ def _checkInputs(dataIds, coaddExposures): if len(bands) > 1: raise RuntimeError(f"GetTemplateTask called with multiple bands: {bands}") band = bands.pop() - photoCalibs = [exposure.get(component="photoCalib") - for exposures in coaddExposures.values() - for exposure in exposures] + photoCalibs = [ + exposure.get(component="photoCalib") + for exposures in coaddExposures.values() + for exposure in exposures + ] if not all([photoCalibs[0] == x for x in photoCalibs]): msg = f"GetTemplateTask called with exposures with different photoCalibs: {photoCalibs}" raise RuntimeError(msg) @@ -447,8 +489,10 @@ def _merge(self, maskedImages, bbox, wcs): continue # nothing in this image overlaps the output maskedImage = maskedImage.subset(clippedBox) # Catch both zero-value and NaN variance plane pixels - good = (maskedImage.variance.array > 0) & (np.isfinite(maskedImage.variance.array)) - weight = maskedImage.variance.array[good]**(-0.5) + good = (maskedImage.variance.array > 0) & ( + np.isfinite(maskedImage.variance.array) + ) + weight = maskedImage.variance.array[good] ** (-0.5) bad = np.isnan(maskedImage.image.array) | ~good # Note that modifying the patch MaskedImage in place is fine; # we're throwing it away at the end anyway. @@ -505,20 +549,22 @@ def _makePsf(self, template, catalog, wcs): """ # CoaddPsf centroid not only must overlap image, but must overlap the # part of image with data. Use centroid of region with data. - boolmask = template.mask.array & template.mask.getPlaneBitMask('NO_DATA') == 0 + boolmask = template.mask.array & template.mask.getPlaneBitMask("NO_DATA") == 0 maskx = afwImage.makeMaskFromArray(boolmask.astype(afwImage.MaskPixel)) centerCoord = afwGeom.SpanSet.fromMask(maskx, 1).computeCentroid() ctrl = self.config.coaddPsf.makeControl() - coaddPsf = CoaddPsf(catalog, wcs, centerCoord, ctrl.warpingKernelName, ctrl.cacheSize) + coaddPsf = CoaddPsf( + catalog, wcs, centerCoord, ctrl.warpingKernelName, ctrl.cacheSize + ) return coaddPsf -class GetDcrTemplateConnections(GetTemplateConnections, - dimensions=("instrument", "visit", "detector", "skymap"), - defaultTemplates={"coaddName": "dcr", - "warpTypeSuffix": "", - "fakesType": ""}): +class GetDcrTemplateConnections( + GetTemplateConnections, + dimensions=("instrument", "visit", "detector", "skymap"), + defaultTemplates={"coaddName": "dcr", "warpTypeSuffix": "", "fakesType": ""}, +): visitInfo = pipeBase.connectionTypes.Input( doc="VisitInfo of calexp used to determine observing conditions.", name="{fakesType}calexp.visitInfo", @@ -531,7 +577,7 @@ class GetDcrTemplateConnections(GetTemplateConnections, storageClass="ExposureF", dimensions=("tract", "patch", "skymap", "band", "subfilter"), multiple=True, - deferLoad=True + deferLoad=True, ) def __init__(self, *, config=None): @@ -539,8 +585,9 @@ def __init__(self, *, config=None): self.inputs.remove("coaddExposures") -class GetDcrTemplateConfig(GetTemplateConfig, - pipelineConnections=GetDcrTemplateConnections): +class GetDcrTemplateConfig( + GetTemplateConfig, pipelineConnections=GetDcrTemplateConnections +): numSubfilters = pexConfig.Field( doc="Number of subfilters in the DcrCoadd.", dtype=int, @@ -559,9 +606,11 @@ class GetDcrTemplateConfig(GetTemplateConfig, def validate(self): if self.effectiveWavelength is None or self.bandwidth is None: - raise ValueError("The effective wavelength and bandwidth of the physical filter " - "must be set in the getTemplate config for DCR coadds. " - "Required until transmission curves are used in DM-13668.") + raise ValueError( + "The effective wavelength and bandwidth of the physical filter " + "must be set in the getTemplate config for DCR coadds. " + "Required until transmission curves are used in DM-13668." + ) class GetDcrTemplateTask(GetTemplateTask): @@ -572,31 +621,40 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) bbox = inputs.pop("bbox") wcs = inputs.pop("wcs") - dcrCoaddExposureHandles = inputs.pop('dcrCoadds') + dcrCoaddExposureHandles = inputs.pop("dcrCoadds") skymap = inputs.pop("skyMap") visitInfo = inputs.pop("visitInfo") # This should not happen with a properly configured execution context. assert not inputs, "runQuantum got more inputs than expected" - results = self.getExposures(dcrCoaddExposureHandles, bbox, skymap, wcs, visitInfo) + results = self.getExposures( + dcrCoaddExposureHandles, bbox, skymap, wcs, visitInfo + ) physical_filter = butlerQC.quantum.dataId["physical_filter"] - outputs = self.run(coaddExposures=results.coaddExposures, - bbox=bbox, - wcs=wcs, - dataIds=results.dataIds, - physical_filter=physical_filter) + outputs = self.run( + coaddExposureHandles=results.coaddExposures, + bbox=bbox, + wcs=wcs, + dataIds=results.dataIds, + physical_filter=physical_filter, + ) butlerQC.put(outputs, outputRefs) - @deprecated(reason="Replaced by getExposures, which uses explicit arguments instead of a kwargs dict. " - "This method will be removed after v29.", - version="v29.0", category=FutureWarning) + @deprecated( + reason="Replaced by getExposures, which uses explicit arguments instead of a kwargs dict. " + "This method will be removed after v29.", + version="v29.0", + category=FutureWarning, + ) def getOverlappingExposures(self, inputs): - return self.getExposures(inputs["dcrCoadds"], - inputs["bbox"], - inputs["skyMap"], - inputs["wcs"], - inputs["visitInfo"]) + return self.getExposures( + inputs["dcrCoadds"], + inputs["bbox"], + inputs["skyMap"], + inputs["wcs"], + inputs["visitInfo"], + ) def getExposures(self, dcrCoaddExposureHandles, bbox, skymap, wcs, visitInfo): """Return lists of coadds and their corresponding dataIds that overlap @@ -653,28 +711,35 @@ def getExposures(self, dcrCoaddExposureHandles, bbox, skymap, wcs, visitInfo): patchList = dict() for coaddRef in dcrCoaddExposureHandles: dataId = coaddRef.dataId - patchWcs = skymap[dataId['tract']].getWcs() - patchBBox = skymap[dataId['tract']][dataId['patch']].getOuterBBox() + subfilter = dataId["subfilter"] + patchWcs = skymap[dataId["tract"]].getWcs() + patchBBox = skymap[dataId["tract"]][dataId["patch"]].getOuterBBox() patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners()) patchPolygon = afwGeom.Polygon(wcs.skyToPixel(patchCorners)) if patchPolygon.intersection(detectorPolygon): - overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea() - self.log.info("Using template input tract=%s, patch=%s, subfilter=%s" % - (dataId['tract'], dataId['patch'], dataId["subfilter"])) - if dataId['tract'] in patchList: - patchList[dataId['tract']].append(dataId['patch']) + overlappingArea += patchPolygon.intersectionSingle( + detectorPolygon + ).calculateArea() + self.log.info( + "Using template input tract=%s, patch=%s, subfilter=%s" + % (dataId["tract"], dataId["patch"], dataId["subfilter"]) + ) + if dataId["tract"] in patchList: + patchList[dataId["tract"]].append(dataId["patch"]) else: - patchList[dataId['tract']] = [dataId['patch'], ] - dataIds[dataId['tract']].append(dataId) + patchList[dataId["tract"]] = [ + dataId["patch"], + ] + if subfilter == 0: + dataIds[dataId["tract"]].append(dataId) if not overlappingArea: - raise pipeBase.NoWorkFound('No patches overlap detector') + raise pipeBase.NoWorkFound("No patches overlap detector") self.checkPatchList(patchList) coaddExposures = self.getDcrModel(patchList, dcrCoaddExposureHandles, visitInfo) - return pipeBase.Struct(coaddExposures=coaddExposures, - dataIds=dataIds) + return pipeBase.Struct(coaddExposures=coaddExposures, dataIds=dataIds) def checkPatchList(self, patchList): """Check that all of the DcrModel subfilters are present for each @@ -694,8 +759,11 @@ def checkPatchList(self, patchList): for tract in patchList: for patch in set(patchList[tract]): if patchList[tract].count(patch) != self.config.numSubfilters: - raise RuntimeError("Invalid number of DcrModel subfilters found: %d vs %d expected", - patchList[tract].count(patch), self.config.numSubfilters) + raise RuntimeError( + "Invalid number of DcrModel subfilters found: %d vs %d expected", + patchList[tract].count(patch), + self.config.numSubfilters, + ) def getDcrModel(self, patchList, coaddRefs, visitInfo): """Build DCR-matched coadds from a list of exposure references. @@ -718,17 +786,24 @@ def getDcrModel(self, patchList, coaddRefs, visitInfo): coaddExposures = collections.defaultdict(list) for tract in patchList: for patch in set(patchList[tract]): - coaddRefList = [coaddRef for coaddRef in coaddRefs - if _selectDataRef(coaddRef, tract, patch)] - - dcrModel = DcrModel.fromQuantum(coaddRefList, - self.config.effectiveWavelength, - self.config.bandwidth, - self.config.numSubfilters) - coaddExposures[tract].append(dcrModel.buildMatchedExposure(visitInfo=visitInfo)) + coaddRefList = [ + coaddRef + for coaddRef in coaddRefs + if _selectDataRef(coaddRef, tract, patch) + ] + + dcrModel = DcrModel.fromQuantum( + coaddRefList, + self.config.effectiveWavelength, + self.config.bandwidth, + self.config.numSubfilters, + ) + coaddExposures[tract].append(dcrModel.buildMatchedExposureHandle(visitInfo=visitInfo)) return coaddExposures def _selectDataRef(coaddRef, tract, patch): - condition = (coaddRef.dataId['tract'] == tract) & (coaddRef.dataId['patch'] == patch) + condition = (coaddRef.dataId["tract"] == tract) & ( + coaddRef.dataId["patch"] == patch + ) return condition