diff --git a/python/lsst/ip/diffim/detectAndMeasure.py b/python/lsst/ip/diffim/detectAndMeasure.py index bcba5a2c7..a263106f9 100644 --- a/python/lsst/ip/diffim/detectAndMeasure.py +++ b/python/lsst/ip/diffim/detectAndMeasure.py @@ -20,6 +20,8 @@ # along with this program. If not, see . import numpy as np +import requests +import os import lsst.afw.detection as afwDetection import lsst.afw.image as afwImage @@ -27,7 +29,8 @@ import lsst.afw.table as afwTable import lsst.daf.base as dafBase import lsst.geom -from lsst.ip.diffim.utils import evaluateMaskFraction, computeDifferenceImageMetrics +from lsst.ip.diffim.utils import (evaluateMaskFraction, computeDifferenceImageMetrics, + populate_sattle_visit_cache) from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask, MaskStreaksTask from lsst.meas.base import ForcedMeasurementTask, ApplyApCorrTask, DetectorVisitIdGeneratorConfig import lsst.meas.deblender @@ -303,6 +306,19 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig, default=True, doc="Raise an algorithm error if no diaSources are detected.", ) + run_sattle = pexConfig.Field( + dtype=bool, + default=False, + doc="If true, dia source bounding boxes will be sent for verification" + "to the sattle service." + ) + sattle_historical = pexConfig.Field( + dtype=bool, + default=False, + doc="If re-running a pipeline that requires sattle, this should be set " + "to True. This will populate sattle's cache with the historic data " + "closest in time to the exposure." + ) idGenerator = DetectorVisitIdGeneratorConfig.make_field() def setDefaults(self): @@ -375,6 +391,15 @@ def setDefaults(self): "STREAK", "INJECTED", "INJECTED_TEMPLATE"] self.skySources.avoidMask = ["DETECTED", "DETECTED_NEGATIVE", "BAD", "NO_DATA", "EDGE"] + def validate(self): + super().validate() + + if self.run_sattle: + if not os.getenv("SATTLE_URI_BASE"): + raise pexConfig.FieldValidationError(DetectAndMeasureConfig.run_sattle, self, + "Sattle requested but SATTLE_URI_BASE " + "environment variable not set.") + class DetectAndMeasureTask(lsst.pipe.base.PipelineTask): """Detect and measure sources on a difference image. @@ -692,6 +717,9 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor self.measureDiaSources(initialDiaSources, science, difference, matchedTemplate) diaSources = self._removeBadSources(initialDiaSources) + if self.config.run_sattle: + diaSources = self.filterSatellites(diaSources, science) + if self.config.doForcedMeasurement: self.measureForcedSources(diaSources, science, difference.getWcs()) @@ -949,6 +977,81 @@ def calculateMetrics(self, science, difference, diaSources, kernelSources): raise BadSubtractionError(ratio=metrics.differenceFootprintRatioStdev, threshold=self.config.badSubtractionVariationThreshold) + def getSattleDiaSourceAllowlist(self, diaSources, science): + """Query the sattle service and determine which diaSources are allowed. + + Parameters + ---------- + diaSources : `lsst.afw.table.SourceCatalog` + The catalog of detected sources. + science : `lsst.afw.image.ExposureF` + Science exposure that was subtracted. + + Returns + ---------- + allow_list : `list` of `int` + diaSourceIds of diaSources that can be made public. + + Raises + ------ + requests.HTTPError + Raised if sattle call does not return success. + """ + wcs = science.getWcs() + visit_info = science.getInfo().getVisitInfo() + visit_id = visit_info.getId() + sattle_uri_base = os.getenv('SATTLE_URI_BASE') + + dia_sources_json = [] + for source in diaSources: + source_bbox = source.getFootprint().getBBox() + corners = wcs.pixelToSky([lsst.geom.Point2D(c) for c in source_bbox.getCorners()]) + bbox_radec = [[pt.getRa().asDegrees(), pt.getDec().asDegrees()] for pt in corners] + dia_sources_json.append({"diasource_id": source["id"], "bbox": bbox_radec}) + + payload = {"visit_id": visit_id, "detector_id": science.getDetector(), "diasources": dia_sources_json, + "historical": self.config.sattle_historical} + + sattle_output = requests.put(f'{sattle_uri_base}/diasource_allow_list', + json=payload) + + # retry once if visit cache is not populated + if sattle_output.status_code == 404: + self.log.warning(f'Visit {visit_id} not found in sattle cache, re-sending') + populate_sattle_visit_cache(visit_info, historical=self.config.sattle_historical) + sattle_output = requests.put(f'{sattle_uri_base}/diasource_allow_list', json=payload) + + sattle_output.raise_for_status() + + return sattle_output.json()['allow_list'] + + def filterSatellites(self, diaSources, science): + """Remove diaSources overlapping predicted satellite positions. + + Parameters + ---------- + diaSources : `lsst.afw.table.SourceCatalog` + The catalog of detected sources. + science : `lsst.afw.image.ExposureF` + Science exposure that was subtracted. + + Returns + ---------- + filterdDiaSources : `lsst.afw.table.SourceCatalog` + Filtered catalog of diaSources + """ + + allow_list = self.getSattleDiaSourceAllowlist(diaSources, science) + + if allow_list: + allow_set = set(allow_list) + allowed_ids = [source['id'] in allow_set for source in diaSources] + diaSources = diaSources[np.array(allowed_ids)].copy(deep=True) + else: + self.log.warning('Sattle allowlist is empty, all diaSources removed') + diaSources = diaSources[0:0].copy(deep=True) + return diaSources + def _runStreakMasking(self, difference): """Do streak masking and optionally save the resulting streak fit parameters in a catalog. diff --git a/python/lsst/ip/diffim/utils.py b/python/lsst/ip/diffim/utils.py index 9648c8f53..fc0ade02e 100644 --- a/python/lsst/ip/diffim/utils.py +++ b/python/lsst/ip/diffim/utils.py @@ -28,6 +28,8 @@ import itertools import numpy as np +import os +import requests import lsst.geom as geom import lsst.afw.detection as afwDetection import lsst.afw.image as afwImage @@ -415,3 +417,40 @@ def footprint_mean(sources, sky=0): differenceFootprintSkyRatioMean=sky_mean, differenceFootprintSkyRatioStdev=sky_std, ) + + +def populate_sattle_visit_cache(visit_info, historical=False): + """Populate a cache of predicted satellite positions in the sattle service. + + Parameters + ---------- + visit_info: `lsst.afw.table.ExposureRecord.visitInfo` + Visit info for the science exposure being processed. + historical: `bool` + Set to True if observations are older than the current day. + + Raises + ------ + requests.HTTPError + Raised if sattle call does not return success. + """ + + visit_mjd = visit_info.getDate().toAstropy().mjd + + exposure_time_days = visit_info.getExposureTime() / 86400.0 + exposure_end_mjd = visit_mjd + exposure_time_days / 2.0 + exposure_start_mjd = visit_mjd - exposure_time_days / 2.0 + + boresight_ra = visit_info.boresightRaDec.getRa().asDegrees() + boresight_dec = visit_info.boresightRaDec.getDec().asDegrees() + + r = requests.put( + f'{os.getenv("SATTLE_URI_BASE")}/visit_cache', + json={"visit_id": visit_info.getId(), + "exposure_start_mjd": exposure_start_mjd, + "exposure_end_mjd": exposure_end_mjd, + "boresight_ra": boresight_ra, + "boresight_dec": boresight_dec, + "historical": historical}) + + r.raise_for_status() diff --git a/tests/test_detectAndMeasure.py b/tests/test_detectAndMeasure.py index 43af77ff1..2b2b4b8ac 100644 --- a/tests/test_detectAndMeasure.py +++ b/tests/test_detectAndMeasure.py @@ -20,17 +20,26 @@ # along with this program. If not, see . import numpy as np +import os import unittest +from unittest import mock +import requests import lsst.afw.geom as afwGeom import lsst.afw.image as afwImage import lsst.afw.math as afwMath import lsst.geom from lsst.ip.diffim import detectAndMeasure, subtractImages +from lsst.afw.table import IdFactory +from lsst.afw.cameraGeom.testUtils import DetectorWrapper import lsst.meas.algorithms as measAlg from lsst.pipe.base import InvalidQuantumError, UpstreamFailureNoWorkFound, AlgorithmError import lsst.utils.tests import lsst.meas.base.tests +import lsst.daf.base as dafBase +from lsst.afw.coord import Observatory, Weather +import lsst.geom as geom +import lsst.pex.config as pexConfig from utils import makeTestImage, checkMask @@ -101,7 +110,8 @@ def _check_values(self, values, minValue=None, maxValue=None): if maxValue is not None: self.assertTrue(np.all(values <= maxValue)) - def _setup_detection(self, doSkySources=True, nSkySources=5, doSubtractBackground=False, **kwargs): + def _setup_detection(self, doSkySources=True, nSkySources=5, + doSubtractBackground=False, run_sattle=False, **kwargs): """Setup and configure the detection and measurement PipelineTask. Parameters @@ -124,6 +134,8 @@ def _setup_detection(self, doSkySources=True, nSkySources=5, doSubtractBackgroun config.skySources.nSources = nSkySources config.update(**kwargs) + config.run_sattle = run_sattle + # Make a realistic id generator so that output catalog ids are useful. dataId = lsst.daf.butler.DataCoordinate.standardize( instrument="I", @@ -154,6 +166,7 @@ def test_detection_xy0(self): fluxLevel = 500 kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel, "x0": 12345, "y0": 67890} science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) difference = science.clone() @@ -223,6 +236,7 @@ def test_measurements_finite(self): "xSize": xSize, "ySize": ySize} science, sources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6, nSrc=1, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel/4, noiseSeed=7, nSrc=1, **kwargs) rng = np.random.RandomState(3) @@ -269,6 +283,7 @@ def test_remove_unphysical(self): kwargs = {"psfSize": 2.4, "xSize": xSize, "ySize": ySize} science, sources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6, nSrc=1, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel/4, noiseSeed=7, nSrc=1, **kwargs) transients, transientSources = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=8, nSrc=1, **kwargs) @@ -403,6 +418,7 @@ def test_missing_mask_planes(self): kwargs = {"psfSize": 2.4, "fluxLevel": fluxLevel, "addMaskPlanes": []} # Use different seeds for the science and template so every source is a diaSource science, sources = makeTestImage(seed=5, noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(seed=6, noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) difference = science.clone() @@ -434,6 +450,7 @@ def test_detect_dipoles(self): "xSize": xSize, "ySize": ySize} dipoleFlag = "ip_diffim_DipoleFit_classification" science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) difference = science.clone() matchedTemplate.image.array[...] = np.roll(matchedTemplate.image.array[...], offset, axis=0) @@ -473,6 +490,7 @@ def test_sky_sources(self): fluxLevel = 500 kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel} science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) transients, transientSources = makeTestImage(seed=transientSeed, psfSize=2.4, nSrc=10, fluxLevel=transientFluxLevel, @@ -517,6 +535,9 @@ def test_exclude_mask_detections(self): radius = 2 kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel} science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) + detector = DetectorWrapper(numAmps=1).detector + science.setDetector(detector) matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) # Configure the detection Task @@ -705,6 +726,97 @@ def test_mask_streaks(self): # Check that the entire image was not masked STREAK self.assertFalse(np.all(streakMaskSet)) + def _setup_sattle_tests(self): + noiseLevel = 1. + staticSeed = 1 + fluxLevel = 500 + shared_kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel, + "x0": 12345, "y0": 67890} + science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, + **shared_kwargs) + science.getInfo().setVisitInfo(makeVisitInfo(id=2)) + detector = DetectorWrapper(numAmps=1).detector + science.setDetector(detector) + matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel / 4, + noiseSeed=7, **shared_kwargs) + difference = science.clone() + + detectionTask = self._setup_detection(doDeblend=True, + badSubtractionRatioThreshold=1., + doSkySources=False, run_sattle=True) + + return science, matchedTemplate, difference, sources, detectionTask + + @mock.patch.dict(os.environ, {"SATTLE_URI_BASE": "fake_host:1234"}) + def test_sattle_not_available(self): + science, matchedTemplate, difference, sources, detectionTask = self._setup_sattle_tests() + + response = MockResponse({"allow_list": []}, 500, "sattle internal error") + with mock.patch('requests.put', return_value=response): + with self.assertRaises(requests.exceptions.HTTPError): + detectionTask.run(science, matchedTemplate, difference, sources, + idFactory=IdFactory.makeSimple()) + + @mock.patch.dict(os.environ, {"SATTLE_URI_BASE": "fake_host:1234"}) + def test_visit_id_not_in_sattle(self): + science, matchedTemplate, difference, sources, detectionTask = self._setup_sattle_tests() + + response = MockResponse({"allow_list": []}, 404, "missing visit cache") + # visit id not in sattle raises + with self.assertRaises(requests.exceptions.HTTPError): + with mock.patch('lsst.ip.diffim.detectAndMeasure.requests.put', + return_value=response): + with mock.patch('lsst.ip.diffim.utils.populate_sattle_visit_cache'): + detectionTask.run(science, matchedTemplate, difference, sources, + idFactory=IdFactory.makeSimple()) + + @mock.patch.dict(os.environ, {"SATTLE_URI_BASE": "fake_host:1234"}) + def test_filter_satellites_some_allowed(self): + science, matchedTemplate, difference, sources, detectionTask = self._setup_sattle_tests() + + allowed_ids = [1, 5] + response = MockResponse({"allow_list": allowed_ids}, 200, "some allowed") + with mock.patch('requests.put', return_value=response): + output = detectionTask.run(science, matchedTemplate, difference, sources, + idFactory=IdFactory.makeSimple()) + + self.assertEqual(len(output.diaSources), 2) + + # Output should be sources 1 and 5 allowed out of 20 + self.assertEqual(set(output.diaSources['id']), set(allowed_ids)) + + @mock.patch.dict(os.environ, {"SATTLE_URI_BASE": "fake_host:1234"}) + def test_filter_satellites_all_allowed(self): + science, matchedTemplate, difference, sources, detectionTask = self._setup_sattle_tests() + + allowed_ids = list(range(1, 21)) + response = MockResponse({"allow_list": allowed_ids}, 200, "all allowed") + # Run detection and check the results + with mock.patch('requests.put', return_value=response): + output = detectionTask.run(science, matchedTemplate, difference, sources, + idFactory=IdFactory.makeSimple()) + + ## Output should be all sources that went in. 20 go in, 20 should come out + self.assertEqual(len(output.diaSources), 20) + + self.assertEqual(set(output.diaSources['id']), set(allowed_ids)) + + @mock.patch.dict(os.environ, {"SATTLE_URI_BASE": "fake_host:1234"}) + def test_filter_satellites_none_allowed(self): + science, matchedTemplate, difference, sources, detectionTask = self._setup_sattle_tests() + + response = MockResponse({"allow_list": []}, 200, "none allowed") + # Run detection and confirm it raises for no diasources + with self.assertRaises(detectAndMeasure.NoDiaSourcesError): + with mock.patch('requests.put', return_value=response): + detectionTask.run(science, matchedTemplate, difference, sources, + idFactory=IdFactory.makeSimple()) + + @mock.patch.dict(os.environ, {"SATTLE_URI_BASE": ""}) + def test_fail_on_sattle_misconfiguration(self): + with self.assertRaises(pexConfig.FieldValidationError): + self._setup_detection(run_sattle=True) + class DetectAndMeasureScoreTest(DetectAndMeasureTestBase, lsst.utils.tests.TestCase): detectionTask = detectAndMeasure.DetectAndMeasureScoreTask @@ -960,6 +1072,7 @@ def test_exclude_mask_detections(self): radius = 2 kwargs = {"seed": staticSeed, "psfSize": 2.4, "fluxLevel": fluxLevel} science, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6, **kwargs) + science.getInfo().setVisitInfo(makeVisitInfo()) matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) subtractTask = subtractImages.AlardLuptonPreconvolveSubtractTask() @@ -1152,6 +1265,40 @@ def testMergeFootprints(self): self.assertEqual((~result.diaSources["is_negative"]).sum(), 3) +def makeVisitInfo(id=1): + """Return a non-NaN visitInfo.""" + return afwImage.VisitInfo(id=id, + exposureTime=10.01, + darkTime=11.02, + date=dafBase.DateTime(65321.1, dafBase.DateTime.MJD, dafBase.DateTime.TAI), + ut1=12345.1, + era=45.1*geom.degrees, + boresightRaDec=geom.SpherePoint(23.1, 73.2, geom.degrees), + boresightAzAlt=geom.SpherePoint(134.5, 33.3, geom.degrees), + boresightAirmass=1.73, + boresightRotAngle=73.2*geom.degrees, + rotType=afwImage.RotType.SKY, + observatory=Observatory( + 11.1*geom.degrees, 22.2*geom.degrees, 0.333), + weather=Weather(1.1, 2.2, 34.5), + ) + + +class MockResponse: + """Provide a mock for requests.put calls""" + def __init__(self, json_data, status_code, text): + self.json_data = json_data + self.status_code = status_code + self.text = text + + def json(self): + return self.json_data + + def raise_for_status(self): + if self.status_code != 200: + raise requests.exceptions.HTTPError + + def setup_module(module): lsst.utils.tests.init() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..74a8457db --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,45 @@ +# This file is part of ip_diffim. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +import requests +from unittest import mock + +from lsst.ip.diffim.utils import populate_sattle_visit_cache +import lsst.utils.tests + +from test_detectAndMeasure import makeVisitInfo, MockResponse + + +class UtilsTest(lsst.utils.tests.TestCase): + + def test_populate_sattle(self): + response = MockResponse({}, 200, "success") + visit_info = makeVisitInfo() + with mock.patch('requests.put', return_value=response): + populate_sattle_visit_cache(visit_info) + + def test_populate_sattle_raises(self): + response = MockResponse({}, 500, "failure") + visit_info = makeVisitInfo() + with mock.patch('requests.put', return_value=response): + with self.assertRaises(requests.exceptions.HTTPError): + populate_sattle_visit_cache(visit_info)