diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py
new file mode 100644
index 000000000..2f790cf01
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py
@@ -0,0 +1 @@
+from .brightStarStack import *
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py
new file mode 100644
index 000000000..77f1989f9
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py
@@ -0,0 +1,285 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Stack bright star postage stamp cutouts to produce an extended PSF model."""
+
+__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"]
+
+import numpy as np
+from lsst.afw.image import ImageF
+from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty
+from lsst.geom import Point2I
+from lsst.meas.algorithms import BrightStarStamps
+from lsst.pex.config import Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
+from lsst.pipe.base.connectionTypes import Input, Output
+from lsst.utils.timer import timeMethod
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+
+class BrightStarStackConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "band"),
+):
+ """Connections for BrightStarStackTask."""
+
+ brightStarStamps = Input(
+ name="brightStarStamps",
+ storageClass="BrightStarStamps",
+ doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
+ dimensions=("visit", "detector"),
+ multiple=True,
+ deferLoad=True,
+ )
+ extendedPsf = Output(
+ name="extendedPsf", # extendedPsfDetector ???
+ storageClass="MaskedImageF", # stamp_imF
+ doc="Extended PSF model, built from stacking bright star cutouts.",
+ dimensions=("band",),
+ )
+
+
+class BrightStarStackConfig(
+ PipelineTaskConfig,
+ pipelineConnections=BrightStarStackConnections,
+):
+ """Configuration parameters for BrightStarStackTask."""
+
+ global_reduced_chi_squared_threshold = Field[float](
+ doc="Threshold for global reduced chi-squared for stamps.",
+ default=5.0,
+ )
+ psf_reduced_chi_squared_threshold = Field[float](
+ doc="Threshold for PSF reduced chi-squared for stamps.",
+ default=50.0,
+ )
+ bright_star_threshold = Field[float](
+ doc="Stars brighter than this magnitude, are considered as bright stars.",
+ default=12.0,
+ )
+ bright_global_reduced_chi_squared_threshold = Field[float](
+ doc="Threshold for global reduced chi-squared for bright star stamps.",
+ default=250.0,
+ )
+ psf_bright_reduced_chi_squared_threshold = Field[float](
+ doc="Threshold for PSF reduced chi-squared for bright star stamps.",
+ default=400.0,
+ )
+
+ bad_mask_planes = ListField[str](
+ doc="Mask planes that identify excluded (masked) pixels.",
+ default=[
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ NEIGHBOR_MASK_PLANE,
+ ],
+ )
+ stack_type = Field[str](
+ default="MEDIAN",
+ doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)",
+ )
+ stack_num_sigma_clip = Field[float](
+ doc="Number of sigma to use for clipping when stacking.",
+ default=3.0,
+ )
+ stack_num_iter = Field[int](
+ doc="Number of iterations to use for clipping when stacking.",
+ default=5,
+ )
+ magnitude_bins = ListField[int](
+ doc="Bins of magnitudes for weighting purposes.",
+ default=[20, 19, 18, 17, 16, 15, 13, 10],
+ )
+ subset_stamp_number = ListField[int](
+ doc="Number of stamps per subset to generate stacked "
+ "images for. The length of this parameter must be equal to the length of magnitude_bins minus one.",
+ default=[300, 200, 150, 100, 100, 100, 1],
+ )
+ min_focal_plane_radius = Field[float](
+ doc="Minimum distance to focal plane center in mm. Stars with a focal plane radius smaller than "
+ "this will be omitted.",
+ default=-1.0,
+ )
+ max_focal_plane_radius = Field[float](
+ doc="Maximum distance to focal plane center in mm. Stars with a focal plane radius greater than "
+ "this will be omitted.",
+ default=2000.0,
+ )
+
+
+class BrightStarStackTask(PipelineTask):
+ """Stack bright star postage stamps to produce an extended PSF model."""
+
+ ConfigClass = BrightStarStackConfig
+ _DefaultName = "brightStarStack"
+ config: BrightStarStackConfig
+
+ def __init__(self, initInputs=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ inputs = butlerQC.get(inputRefs)
+ output = self.run(**inputs)
+ butlerQC.put(output, outputRefs)
+
+ def _applyStampFit(self, stamp):
+ """Apply fitted stamp components to a single bright star stamp."""
+ stampMI = stamp.stamp_im
+ stamp_bbox = stampMI.getBBox()
+
+ x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange())
+
+ x_plane = ImageF((x_grid * stamp.gradient_x).astype(np.float32), xy0=stampMI.getXY0())
+ y_plane = ImageF((y_grid * stamp.gradient_y).astype(np.float32), xy0=stampMI.getXY0())
+
+ x_curve = ImageF((x_grid**2 * stamp.curvature_x).astype(np.float32), xy0=stampMI.getXY0())
+ y_curve = ImageF((y_grid**2 * stamp.curvature_y).astype(np.float32), xy0=stampMI.getXY0())
+ xy_curve = ImageF((x_grid * y_grid * stamp.curvature_xy).astype(np.float32), xy0=stampMI.getXY0())
+
+ stampMI -= stamp.pedestal
+ stampMI -= x_plane
+ stampMI -= y_plane
+ stampMI -= x_curve
+ stampMI -= y_curve
+ stampMI -= xy_curve
+ stampMI /= stamp.scale
+
+ @timeMethod
+ def run(
+ self,
+ brightStarStamps: BrightStarStamps,
+ ):
+ """Identify bright stars within an exposure using a reference catalog,
+ extract stamps around each, then preprocess them.
+
+ Bright star preprocessing steps are: shifting, warping and potentially
+ rotating them to the same pixel grid; computing their annular flux,
+ and; normalizing them.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The image from which bright star stamps should be extracted.
+ inputBackground : `~lsst.afw.image.Background`
+ The background model for the input exposure.
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
+ Loader to find objects within a reference catalog.
+ dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The dataId of the exposure (including detector) that bright stars
+ should be extracted from.
+
+ Returns
+ -------
+ brightStarResults : `~lsst.pipe.base.Struct`
+ Results as a struct with attributes:
+
+ ``brightStarStamps``
+ (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
+ """
+ stack_type_property = stringToStatisticsProperty(self.config.stack_type)
+ statistics_control = StatisticsControl(
+ numSigmaClip=self.config.stack_num_sigma_clip,
+ numIter=self.config.stack_num_iter,
+ )
+
+ mag_bins_dict = {}
+ subset_stampMIs = {}
+ self.metadata["psf_star_count"] = {}
+ self.metadata["psf_star_count"]["all"] = 0
+ for i in range(len(self.config.subset_stamp_number)):
+ self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] = 0
+ for stampsDDH in brightStarStamps:
+ stamps = stampsDDH.get()
+ self.metadata["psf_star_count"]["all"] += len(stamps)
+ for stamp in stamps:
+ if stamp.ref_mag >= self.config.bright_star_threshold:
+ global_reduced_chi_squared_threshold = self.config.global_reduced_chi_squared_threshold
+ psf_reduced_chi_squared_threshold = self.config.psf_reduced_chi_squared_threshold
+ else:
+ global_reduced_chi_squared_threshold = (
+ self.config.bright_global_reduced_chi_squared_threshold
+ )
+ psf_reduced_chi_squared_threshold = self.config.psf_bright_reduced_chi_squared_threshold
+ for i in range(len(self.config.subset_stamp_number)):
+ if (
+ stamp.global_reduced_chi_squared > global_reduced_chi_squared_threshold
+ or stamp.psf_reduced_chi_squared > psf_reduced_chi_squared_threshold
+ or stamp.focal_plane_radius < self.config.min_focal_plane_radius
+ or stamp.focal_plane_radius > self.config.max_focal_plane_radius
+ ):
+ continue
+
+ if (
+ stamp.ref_mag < self.config.magnitude_bins[i]
+ and stamp.ref_mag > self.config.magnitude_bins[i + 1]
+ ):
+ self._applyStampFit(stamp)
+ if not self.config.magnitude_bins[i + 1] in mag_bins_dict.keys():
+ mag_bins_dict[self.config.magnitude_bins[i + 1]] = []
+ stampMI = stamp.stamp_im
+ mag_bins_dict[self.config.magnitude_bins[i + 1]].append(stampMI)
+ bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes)
+ statistics_control.setAndMask(bad_mask_bit_mask)
+ if (
+ len(mag_bins_dict[self.config.magnitude_bins[i + 1]])
+ == self.config.subset_stamp_number[i]
+ ):
+ if self.config.magnitude_bins[i + 1] not in subset_stampMIs.keys():
+ subset_stampMIs[self.config.magnitude_bins[i + 1]] = []
+ subset_stampMIs[self.config.magnitude_bins[i + 1]].append(
+ statisticsStack(
+ mag_bins_dict[self.config.magnitude_bins[i + 1]],
+ stack_type_property,
+ statistics_control,
+ )
+ )
+ self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] += len(
+ mag_bins_dict[self.config.magnitude_bins[i + 1]]
+ )
+ mag_bins_dict[self.config.magnitude_bins[i + 1]] = []
+
+ for key in mag_bins_dict.keys():
+ if key not in subset_stampMIs.keys():
+ subset_stampMIs[key] = []
+ subset_stampMIs[key].append(
+ statisticsStack(mag_bins_dict[key], stack_type_property, statistics_control)
+ )
+ self.metadata["psf_star_count"][str(key)] += len(mag_bins_dict[key])
+
+ final_subset_stampMIs = []
+ for key in subset_stampMIs.keys():
+ final_subset_stampMIs.extend(subset_stampMIs[key])
+ bad_mask_bit_mask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes)
+ statistics_control.setAndMask(bad_mask_bit_mask)
+ extendedPsfMI = statisticsStack(final_subset_stampMIs, stack_type_property, statistics_control)
+
+ extendedPsfExtent = extendedPsfMI.getBBox().getDimensions()
+ extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2))
+ extendedPsfMI.setXY0(extendedPsfOrigin)
+
+ return Struct(extendedPsf=extendedPsfMI)
diff --git a/tests/test_brightStarStack.py b/tests/test_brightStarStack.py
new file mode 100644
index 000000000..073464013
--- /dev/null
+++ b/tests/test_brightStarStack.py
@@ -0,0 +1,226 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import unittest
+
+import lsst.afw.image
+import lsst.utils.tests
+import numpy as np
+from lsst.afw.image import ImageF, MaskedImageF
+from lsst.pipe.tasks.brightStarSubtraction import BrightStarStackConfig, BrightStarStackTask
+
+
+# Mock class to simulate the DeferredDatasetHandle (DDH) behavior
+class MockHandle:
+ def __init__(self, content):
+ self.content = content
+
+ def get(self):
+ return self.content
+
+
+# Mock class to simulate a single BrightStarStamp
+class MockStamp:
+ def __init__(self, stamp_im, mag, fit_params, stats):
+ self.stamp_im = stamp_im
+ self.ref_mag = mag
+
+ # Unpack fit parameters
+ self.scale = fit_params.get("scale", 1.0)
+ self.pedestal = fit_params.get("pedestal", 0.0)
+ self.gradient_x = fit_params.get("gradient_x", 0.0)
+ self.gradient_y = fit_params.get("gradient_y", 0.0)
+ self.curvature_x = fit_params.get("curvature_x", 0.0)
+ self.curvature_y = fit_params.get("curvature_y", 0.0)
+ self.curvature_xy = fit_params.get("curvature_xy", 0.0)
+
+ # Unpack statistics for filtering
+ self.global_reduced_chi_squared = stats.get("global_chi2", 1.0)
+ self.psf_reduced_chi_squared = stats.get("psf_chi2", 1.0)
+ self.bright_global_reduced_chi_squared = stats.get("bright_global_chi2", 1.0)
+ self.psf_bright_reduced_chi_squared = stats.get("psf_bright_chi2", 1.0)
+ self.bright_star_threshold = stats.get("brights_threshold", 100.0)
+ self.focal_plane_radius = stats.get("fp_radius", 100.0)
+
+
+class BrightStarStackTestCase(lsst.utils.tests.TestCase):
+ def setUp(self):
+ # Define fit values
+ self.scale = 10.0
+ self.pedestal = 50.0
+ self.x_gradient = 0.5
+ self.y_gradient = -0.5
+ self.curvature_x = 0.01
+ self.curvature_y = 0.01
+ self.curvature_xy = 0.005
+
+ self.fit_params = {
+ "scale": self.scale,
+ "pedestal": self.pedestal,
+ "gradient_x": self.x_gradient,
+ "gradient_y": self.y_gradient,
+ "curvature_x": self.curvature_x,
+ "curvature_y": self.curvature_y,
+ "curvature_xy": self.curvature_xy,
+ }
+
+ # Create the "Clean" PSF (a simple Gaussian)
+ self.dim = 51
+ x_coords = np.linspace(-25, 25, self.dim)
+ y_coords = np.linspace(-25, 25, self.dim)
+ x_grid, y_grid = np.meshgrid(x_coords, y_coords)
+
+ sigma = 5.0
+ dist_sq = x_grid**2 + y_grid**2
+ self.clean_array = np.exp(-dist_sq / (2 * sigma**2))
+
+ # Create the "star" Image (What the task receives)
+ # Apply scaling
+ star_array = self.clean_array * self.scale
+
+ # Add background terms
+ x_indices, y_indices = np.meshgrid(np.arange(self.dim), np.arange(self.dim))
+
+ star_array += self.pedestal
+ star_array += x_indices * self.x_gradient
+ star_array += y_indices * self.y_gradient
+ star_array += (x_indices**2) * self.curvature_x
+ star_array += (y_indices**2) * self.curvature_y
+ star_array += (x_indices * y_indices) * self.curvature_xy
+
+ # Create MaskedImage
+ stampIm = ImageF(star_array.astype(np.float32))
+ stampVa = ImageF(stampIm.getBBox(), 1.0)
+ self.stampMI = MaskedImageF(image=stampIm, variance=stampVa)
+
+ # Initialize the mask planes required
+ badMaskPlanes = [
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ "NEIGHBOR",
+ ]
+ _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes]
+
+ def test_applyStampFit(self):
+ """Test that _applyStampFit correctly removes background and normalizes."""
+ config = BrightStarStackConfig()
+ task = BrightStarStackTask(config=config)
+
+ # Create a mock stamp
+ stamp_mi_copy = self.stampMI.clone()
+ mock_stamp = MockStamp(stamp_mi_copy, mag=10.0, fit_params=self.fit_params, stats={})
+
+ # Run the method
+ task._applyStampFit(mock_stamp)
+
+ # The result should be the clean array (normalized to scale 1.0)
+ result_array = mock_stamp.stamp_im.image.array
+
+ # Allow for small floating point discrepancies
+ np.testing.assert_allclose(result_array, self.clean_array, atol=1e-5)
+
+ def test_run(self):
+ """Test the full run method: filtering, binning, and stacking."""
+ config = BrightStarStackConfig()
+ # Set config to ensure our test stamps are included
+ config.magnitude_bins = [11, 9]
+ config.subset_stamp_number = [1]
+ config.stack_type = "MEDIAN"
+
+ task = BrightStarStackTask(config=config)
+
+ valid_stats = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 100.0}
+ invalid_stats = {"global_chi2": 1e9, "psf_chi2": 1e9, "fp_radius": 100.0}
+
+ stamp1 = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=valid_stats)
+ stamp2 = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=valid_stats)
+
+ # This stamp should be ignored
+ bad_stamp = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=invalid_stats)
+
+ # Create mock input structure
+ # brightStarStamps is a list of handles
+ input_stamps = [MockHandle([stamp1, bad_stamp]), MockHandle([stamp2])]
+
+ result = task.run(brightStarStamps=input_stamps)
+
+ # Verify output exists
+ self.assertIsNotNone(result.extendedPsf)
+
+ # Verify output dimensions match input
+ self.assertEqual(result.extendedPsf.getDimensions(), self.stampMI.getDimensions())
+
+ # Verify the calculation
+ # Since we stacked identical "clean" stamps (after fit application),
+ # the result should match self.clean_array
+ result_array = result.extendedPsf.image.array
+ np.testing.assert_allclose(result_array, self.clean_array, atol=1e-5)
+
+ def test_filtering_logic(self):
+ """Test that stamps outside focal plane radius or thresholds are skipped."""
+ config = BrightStarStackConfig()
+ config.min_focal_plane_radius = 50.0
+ config.max_focal_plane_radius = 150.0
+ config.global_reduced_chi_squared_threshold = 5.0
+ config.magnitude_bins = [15, 11, 9]
+ config.subset_stamp_number = [100, 1]
+
+ task = BrightStarStackTask(config=config)
+
+ good_stats = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 100.0}
+ bad_radius_low = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 10.0}
+ bad_radius_high = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 2000.0}
+ bad_chi2 = {"global_chi2": 100.0, "psf_chi2": 1.0, "fp_radius": 100.0}
+
+ stamps = [
+ MockStamp(self.stampMI.clone(), 10.0, self.fit_params, good_stats),
+ MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_radius_low),
+ MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_radius_high),
+ MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_chi2),
+ ]
+
+ input_stamps = [MockHandle(stamps)]
+ task.run(brightStarStamps=input_stamps)
+
+ bin_key = "9" # Based on config magnitude_bins=[11, 9], the lower bound is 9
+
+ self.assertEqual(task.metadata["psf_star_count"]["all"], 4)
+
+ self.assertEqual(task.metadata["psf_star_count"][bin_key], 2)
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()