Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from lsst.ip.diffim.utils import (evaluateMaskFraction, computeDifferenceImageMetrics,
populate_sattle_visit_cache)
from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask, MaskStreaksTask
from lsst.meas.algorithms import FindGlintTrailsTask
from lsst.meas.base import ForcedMeasurementTask, ApplyApCorrTask, DetectorVisitIdGeneratorConfig
import lsst.meas.deblender
import lsst.meas.extensions.trailedSources # noqa: F401
Expand Down Expand Up @@ -138,13 +139,21 @@ class DetectAndMeasureConnections(pipeBase.PipelineTaskConnections,
dimensions=("instrument", "visit", "detector"),
name="{fakesType}{coaddName}Diff_streaks",
)
glintTrailInfo = pipeBase.connectionTypes.Output(
doc='Dict of fit parameters for glint trails in the catalog.',
storageClass="ArrowNumpyDict",
dimensions=("instrument", "visit", "detector"),
name="trailed_glints",
)

def __init__(self, *, config):
super().__init__(config=config)
if not (self.config.writeStreakInfo and self.config.doMaskStreaks):
self.outputs.remove("maskedStreaks")
if not (self.config.doSubtractBackground and self.config.doWriteBackground):
self.outputs.remove("differenceBackground")
if not (self.config.writeGlintInfo):
self.outputs.remove("glintTrailInfo")


class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
Expand Down Expand Up @@ -260,6 +269,15 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
doc="Record the parameters of any detected streaks. For LSST, this should be turned off except for "
"development work."
)
findGlints = pexConfig.ConfigurableField(
target=FindGlintTrailsTask,
doc="Subtask for finding glint trails, usually caused by satellites or debris."
)
writeGlintInfo = pexConfig.Field(
dtype=bool,
default=True,
doc="Record the parameters of any detected glint trails."
)
setPrimaryFlags = pexConfig.ConfigurableField(
target=SetPrimaryFlagsTask,
doc="Task to add isPrimary and deblending-related flags to the catalog."
Expand Down Expand Up @@ -455,6 +473,8 @@ def __init__(self, **kwargs):
if self.config.doMaskStreaks:
self.makeSubtask("maskStreaks")
self.makeSubtask("streakDetection")
self.makeSubtask("findGlints")
self.schema.addField("glint_trail", "Flag", "DiaSource is part of a glint trail.")

# To get the "merge_*" fields in the schema; have to re-initialize
# this later, once we have a peak schema post-detection.
Expand Down Expand Up @@ -492,6 +512,7 @@ def runQuantum(self, butlerQC: pipeBase.QuantumContext,
measurementResults.subtractedMeasuredExposure,
measurementResults.diaSources,
measurementResults.maskedStreaks,
measurementResults.glintTrailInfo,
log=self.log
)
butlerQC.put(measurementResults, outputRefs)
Expand Down Expand Up @@ -715,6 +736,13 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor
initialDiaSources = initialDiaSources.copy(deep=True)

self.measureDiaSources(initialDiaSources, science, difference, matchedTemplate)

# Add a column for glint trail diaSources, but do not remove them
initialDiaSources, trail_parameters = self._find_glint_trails(initialDiaSources)
if self.config.writeGlintInfo:
measurementResults.mergeItems(trail_parameters, 'glintTrailInfo')

# Remove unphysical diaSources per config.badSourceFlags
diaSources = self._removeBadSources(initialDiaSources)

if self.config.run_sattle:
Expand Down Expand Up @@ -835,6 +863,39 @@ def _removeBadSources(self, diaSources):
self.log.info("Removed %d unphysical sources.", nBadTotal)
return diaSources[selector].copy(deep=True)

def _find_glint_trails(self, diaSources):
"""Define a new flag column for diaSources that are in a glint trail.

Parameters
----------
diaSources : `lsst.afw.table.SourceCatalog`
The catalog of detected sources.

Returns
-------
diaSources : `lsst.afw.table.SourceCatalog`
The updated catalog of detected sources, with a new bool column
called 'glint_trail' added.

trail_parameters : `dict`
Parameters of all the trails that were found.
"""
trailed_glints = self.findGlints.run(diaSources)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
trailed_glints = self.findGlints.run(diaSources)
glint_trails = self.findGlints.run(diaSources)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the new Butler dataset, not the new diaSource column, so I am leaving the name as-is.

glint_mask = [True if id in trailed_glints.trailed_ids else False for id in diaSources['id']]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
glint_mask = [True if id in trailed_glints.trailed_ids else False for id in diaSources['id']]
glint_mask = [id in glint_trails.trailed_ids for id in diaSources['id']]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this!

diaSources['glint_trail'] = np.array(glint_mask)

slopes = np.array([trail.slope for trail in trailed_glints.parameters])
intercepts = np.array([trail.intercept for trail in trailed_glints.parameters])
stderrs = np.array([trail.stderr for trail in trailed_glints.parameters])
lengths = np.array([trail.length for trail in trailed_glints.parameters])
angles = np.array([trail.angle for trail in trailed_glints.parameters])
parameters = {'slopes': slopes, 'intercepts': intercepts, 'stderrs': stderrs, 'lengths': lengths,
'angles': angles}

trail_parameters = pipeBase.Struct(glintTrailInfo=parameters)

return diaSources, trail_parameters

def addSkySources(self, diaSources, mask, seed,
subtask=None):
"""Add sources in empty regions of the difference image
Expand Down
30 changes: 29 additions & 1 deletion tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def test_filter_satellites_all_allowed(self):
output = detectionTask.run(science, matchedTemplate, difference, sources,
idFactory=IdFactory.makeSimple())

## Output should be all sources that went in. 20 go in, 20 should come out
# Output should be all sources that went in. 20 go in, 20 should come out
self.assertEqual(len(output.diaSources), 20)

self.assertEqual(set(output.diaSources['id']), set(allowed_ids))
Expand All @@ -817,6 +817,34 @@ def test_fail_on_sattle_misconfiguration(self):
with self.assertRaises(pexConfig.FieldValidationError):
self._setup_detection(run_sattle=True)

def test_trailed_glints(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to test that writeGlintInfo provides the expected outputs as well.

"""Test that the glint_trail column works, and that
the trailed_glints output contains the expected information.
"""
noiseLevel = 1.
staticSeed = 1
diffim, diaSources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6)
self._check_values(diaSources['glint_trail'])

# Run detection and return the output Struct so we can check it
def _detection_wrapper(diffim, diaSources):
detectionTask = self._setup_detection()
scienceBase, sources = makeTestImage(noiseLevel=noiseLevel, noiseSeed=6)
matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7)
science = scienceBase.clone()
science.maskedImage -= diffim.maskedImage
difference = science.clone()
difference.maskedImage -= matchedTemplate.maskedImage
output = detectionTask.run(science, matchedTemplate, difference, sources)
return output

output = _detection_wrapper(diffim, diaSources)
self.assertTrue('slopes' in output.glintTrailInfo)
self.assertTrue('intercepts' in output.glintTrailInfo)
self.assertTrue('stderrs' in output.glintTrailInfo)
self.assertTrue('lengths' in output.glintTrailInfo)
self.assertTrue('angles' in output.glintTrailInfo)


class DetectAndMeasureScoreTest(DetectAndMeasureTestBase, lsst.utils.tests.TestCase):
detectionTask = detectAndMeasure.DetectAndMeasureScoreTask
Expand Down
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ def _makeTruthSchema():
schema.getAliasMap().set("slot_CalibFlux", "truth")
schema.getAliasMap().set("slot_ApFlux", "truth")
schema.getAliasMap().set("slot_PsfFlux", "truth")
schema.addField("glint_trail", "Flag", "testing flag.")
return keys, schema


Expand Down