diff --git a/doc/changes/DM-47738.feature.rst b/doc/changes/DM-47738.feature.rst
deleted file mode 100644
index d96cacf..0000000
--- a/doc/changes/DM-47738.feature.rst
+++ /dev/null
@@ -1,3 +0,0 @@
-* Added ``Image.trimmed`` method to remove data below a threshold from an Image.
-* Added ``Image.at`` method to extract a single pixel from an Image.
-* Added slicing for ``Observation`` to slice along spectral or spatial dimension
\ No newline at end of file
diff --git a/doc/changes/DM-49537.md b/doc/changes/DM-49537.md
deleted file mode 100644
index 33d268a..0000000
--- a/doc/changes/DM-49537.md
+++ /dev/null
@@ -1 +0,0 @@
-Improved the storage of sources and components using a registry. This nmakes it easier for end users to persist custom component types.
\ No newline at end of file
diff --git a/doc/lsst.scarlet.lite/changes.rst b/doc/lsst.scarlet.lite/changes.rst
new file mode 100644
index 0000000..2c2f5bf
--- /dev/null
+++ b/doc/lsst.scarlet.lite/changes.rst
@@ -0,0 +1,38 @@
+.. _lsst.scarlet.lite-changes:
+
+=================
+v30.0.0 Changes
+=================
+
+Improved Slicing
+----------------
+
+In v29.0 slicing was only supported for the ``Image`` class. In ``v30.0`` slicing has been extended to ``Blend``, ``Source``, ``Component``, and ``Observation`` classes. This allows uers to use a subset of bands or change the band order of each of these classes. ``Observation`` can be sliced along the spatial dimension as well.
+
+New Image methods
+-----------------
+- ``Image.trimmed`` method added to remove data below a threshold from an Image.
+- ``Image.at`` method added to extract a single pixel from an Image.
+
+Serialization Improvments
+-------------------------
+Serialization received a major update in ``v30.0`` to improve performance and usability. The major upgrade is a ``Migration Registry`` that registers all scarlet serializable classes and allows users to register their own custom classes along with function to migrate between different versions of the class.
+This allows scarlet to automatically handle versioning of serialized objects and migrate them to the latest version when deserializing.
+
+As part of this update the classes used for serialization were expanded to included base classes such as ``BlendBaseData``, ``SourceBaseData``, and ``ComponentBaseData`` to make it easier for users to extend serialization to their own custom classes.
+
+Copying and Deep Copying
+------------------------
+To support the serialization improvements, ``__copy__`` and ``__deepcopy__`` methods were added to nearly all scarlet lite classes. This allows users to create copies of scarlet objects using the standard library ``copy`` and ``deepcopy`` modules.
+
+Initialization Changes
+----------------------
+Initialization has a few changes to make it both more customizable and more standardized
+- The ineffective ``FactorizedWaveletInitialization`` was deprecated and the ``FactorizedChirInitialization`` was changed to ``FactorizedInitialization`` as testing has shown that it is a superior algorithm.
+- A large number of new parameters were added to ``FactorizedInitialization`` to make it easier for users to initialize sources using more tunable constraints and detection images that may differ from the observation image.
+
+Other Changes
+-------------
+- The detection algorithms have been updated and while not ready for production users should see improved performance.
+- A more customizable ``conserve_flux`` function was added to ``scarlet``
+- Positions are rounded instead of truncated during initialization to remove a bias from inaccurate sources positions.
diff --git a/doc/lsst.scarlet.lite/index.rst b/doc/lsst.scarlet.lite/index.rst
index c510f93..201227c 100644
--- a/doc/lsst.scarlet.lite/index.rst
+++ b/doc/lsst.scarlet.lite/index.rst
@@ -20,6 +20,7 @@ toctree linking to topics related to using the module's APIs.
getting_started
detection
+ changes
.. _lsst.scarlet.lite-contributing:
diff --git a/pyproject.toml b/pyproject.toml
index a3041be..8d181df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,7 @@ test = [
"pytest >= 3.2",
]
yaml = ["pyyaml >= 5.1"]
-plotting = ["matplotlib", "astropy < 7"]
+plotting = ["matplotlib", "astropy >= 6.1"]
[tool.setuptools.packages.find]
where = ["python"]
diff --git a/python/lsst/scarlet/lite/bbox.py b/python/lsst/scarlet/lite/bbox.py
index f861673..73cb04c 100644
--- a/python/lsst/scarlet/lite/bbox.py
+++ b/python/lsst/scarlet/lite/bbox.py
@@ -23,7 +23,8 @@
__all__ = ["Box", "overlapped_slices"]
-from typing import Sequence, cast
+from copy import deepcopy
+from typing import Any, Sequence, cast
import numpy as np
@@ -462,6 +463,15 @@ def __matmul__(self, bbox: Box) -> Box:
result = Box.from_bounds(*bounds)
return result
+ def __deepcopy__(self, memo: dict[int, Any]) -> Box:
+ """Deep copy of the box"""
+ my_id = id(self)
+ if my_id in memo:
+ return memo[my_id]
+ result = Box(deepcopy(self.shape), origin=deepcopy(self.origin))
+ memo[my_id] = result
+ return result
+
def __copy__(self) -> Box:
"""Copy of the box"""
return Box(self.shape, origin=self.origin)
diff --git a/python/lsst/scarlet/lite/blend.py b/python/lsst/scarlet/lite/blend.py
index e8610ab..4919585 100644
--- a/python/lsst/scarlet/lite/blend.py
+++ b/python/lsst/scarlet/lite/blend.py
@@ -24,7 +24,8 @@
__all__ = ["Blend"]
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Callable, Self, Sequence, cast
import numpy as np
@@ -32,7 +33,7 @@
from .component import Component, FactorizedComponent
from .image import Image
from .observation import Observation
-from .source import Source
+from .source import Source, SourceBase
if TYPE_CHECKING:
from .io import ScarletBlendData, ScarletSourceBaseData
@@ -57,7 +58,7 @@ class BlendBase(ABC):
Additional metadata to store with the blend.
"""
- sources: list[Source]
+ sources: Sequence[SourceBase]
observation: Observation
metadata: dict | None
@@ -80,6 +81,72 @@ def components(self) -> list[Component]:
"""
return [c for src in self.sources for c in src.components]
+ @abstractmethod
+ def __getitem__(self, indices: Any) -> Self:
+ """Get a sub-blend corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices :
+ The indices to use to slice the blend.
+
+ Returns
+ -------
+ sub_blend :
+ A new `BlendBase` instance containing only data from the
+ specified bands in the specified order.
+
+ Raises
+ ------
+ IndexError :
+ If the indices contain bands not included in the original
+ blend or any spatial indices are given.
+ """
+
+ @abstractmethod
+ def __copy__(self) -> Self:
+ """Create a copy of this blend.
+
+ Returns
+ -------
+ blend : BlendBase
+ A new blend that is a copy of this one.
+ """
+
+ @abstractmethod
+ def __deepcopy__(self, memo: dict[int, Any]) -> Self:
+ """Create a deep copy of this blend.
+
+ Parameters
+ ----------
+ memo : dict[int, Any]
+ A memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ blend : BlendBase
+ A new blend that is a deep copy of this one.
+ """
+
+ def copy(self, deep: bool = False) -> Self:
+ """Create a copy of this blend.
+
+ Parameters
+ ----------
+ deep :
+ If `True`, a deep copy is made. If `False`, a shallow copy is made.
+ Default is `False`.
+
+ Returns
+ -------
+ blend : Self
+ A new blend that is a copy of this one.
+ """
+ if deep:
+ return self.__deepcopy__({})
+ else:
+ return self.__copy__()
+
@abstractmethod
def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
"""Generate a model of the entire blend.
@@ -128,6 +195,8 @@ class Blend(BlendBase):
Additional metadata to store with the blend.
"""
+ sources: list[Source]
+
def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None):
self.sources = list(sources)
self.observation = observation
@@ -433,3 +502,69 @@ def to_data(self) -> ScarletBlendData:
)
return blend_data
+
+ def __getitem__(self, indices: Any) -> Blend:
+ """Get a sub-blend corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices :
+ The indices to use to slice the blend.
+
+ Returns
+ -------
+ blend :
+ A new `Blend` instance containing only data from the
+ specified bands in the specified order.
+
+ Raises
+ ------
+ IndexError :
+ If the indices contain bands not included in the original
+ blend or a bounding box is given.
+ """
+ return Blend(
+ sources=[src[indices] for src in self.sources],
+ observation=self.observation[indices],
+ metadata=self.metadata,
+ )
+
+ def __copy__(self) -> Blend:
+ """Create a copy of this blend.
+
+ Returns
+ -------
+ blend : Blend
+ A new blend that is a copy of this one.
+ """
+ return Blend(sources=self.sources, observation=self.observation, metadata=self.metadata)
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> Blend:
+ """Create a deep copy of this blend.
+
+ Parameters
+ ----------
+ memo : dict[int, Any]
+ A memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ blend : Blend
+ A new blend that is a deep copy of this one.
+ """
+ # Check if already copied
+ if id(self) in memo:
+ return memo[id(self)]
+
+ # Create placeholder and add to memo FIRST
+ blend = Blend.__new__(Blend)
+ memo[id(self)] = blend
+
+ # Now safely initialize the placeholder with deepcopied arguments
+ blend.__init__( # type: ignore[misc]
+ sources=[deepcopy(src, memo) for src in self.sources],
+ observation=deepcopy(self.observation, memo),
+ metadata=deepcopy(self.metadata, memo),
+ )
+
+ return blend
diff --git a/python/lsst/scarlet/lite/component.py b/python/lsst/scarlet/lite/component.py
index 81f109d..0e02d83 100644
--- a/python/lsst/scarlet/lite/component.py
+++ b/python/lsst/scarlet/lite/component.py
@@ -18,11 +18,13 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-
from __future__ import annotations
+from copy import deepcopy
+
__all__ = [
"Component",
+ "CubeComponent",
"FactorizedComponent",
"default_fista_parameterization",
"default_adaprox_parameterization",
@@ -30,7 +32,7 @@
from abc import ABC, abstractmethod
from functools import partial
-from typing import TYPE_CHECKING, Callable, cast
+from typing import TYPE_CHECKING, Any, Callable, cast
import numpy as np
@@ -38,9 +40,14 @@
from .image import Image
from .operators import Monotonicity, prox_uncentered_symmetry
from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step
+from .utils import convert_indices
if TYPE_CHECKING:
- from .io import ScarletComponentBaseData
+ from .io import ScarletComponentBaseData, ScarletCubeComponentData
+
+import logging
+
+Logger = logging.getLogger(__name__)
class Component(ABC):
@@ -126,6 +133,64 @@ def to_data(self) -> ScarletComponentBaseData:
The data object containing the component information
"""
+ @abstractmethod
+ def __getitem__(self, indices: Any) -> Component:
+ """Get a sub-component corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices: Any
+ The indices to use to slice the component model.
+
+ Returns
+ -------
+ sub_component: Component
+ A new component that is a sub-component of this one.
+
+ Raises
+ ------
+ IndexError :
+ If the index includes a ``Box`` or spatial indices.
+ """
+
+ @abstractmethod
+ def __copy__(self) -> Component:
+ """Create a copy of this component.
+
+ Returns
+ -------
+ component : Component
+ A new component that is a copy of this one.
+ """
+
+ @abstractmethod
+ def __deepcopy__(self, memo: dict[int, Any]) -> Component:
+ """Create a deep copy of this component.
+
+ Returns
+ -------
+ component : Component
+ A new component that is a deep copy of this one.
+ """
+
+ def copy(self, deep: bool = False) -> Component:
+ """Create a copy of this component.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ If `True`, a deep copy is made. If `False`, a shallow copy is made.
+ Default is `False`.
+
+ Returns
+ -------
+ component : Component
+ A new component that is a copy of this one.
+ """
+ if deep:
+ return self.__deepcopy__({})
+ return self.__copy__()
+
class FactorizedComponent(Component):
"""A component that can be factorized into spectrum and morphology
@@ -395,6 +460,234 @@ def __str__(self):
def __repr__(self):
return self.__str__()
+ def __getitem__(self, indices: Any) -> FactorizedComponent:
+ """Get a sub-component corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices: Any
+ The indices to use to slice the component model.
+
+ Returns
+ -------
+ component: FactorizedComponent
+ A new component that is a sub-component of this one.
+
+ Raises
+ ------
+ IndexError :
+ If the index includes a ``Box`` or spatial indices.
+ """
+ # Convert the band indices into numerical indices
+ band_indices = convert_indices(self.bands, indices)
+ if isinstance(band_indices, slice):
+ bands = self.bands[band_indices]
+ else:
+ bands = tuple(self.bands[i] for i in band_indices)
+
+ # Slice the spectrum
+ spectrum = self._spectrum.x[band_indices,]
+
+ return FactorizedComponent(
+ bands=bands,
+ spectrum=spectrum,
+ morph=self.morph,
+ bbox=self.bbox,
+ peak=self.peak,
+ bg_rms=self.bg_rms,
+ bg_thresh=self.bg_thresh,
+ floor=self.floor,
+ monotonicity=self.monotonicity,
+ padding=self.padding,
+ is_symmetric=self.is_symmetric,
+ )
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> FactorizedComponent:
+ """Create a deep copy of this component.
+
+ Parameters
+ ----------
+ memo: dict[int, Any]
+ The memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ component : FactorizedComponent
+ A new component that is a deep copy of this one.
+ """
+ # Check if already copied
+ if id(self) in memo:
+ return memo[id(self)]
+
+ # Create placeholder and add to memo FIRST
+ component = FactorizedComponent.__new__(FactorizedComponent)
+ memo[id(self)] = component
+
+ # Now safely initialize the placeholder with deepcopied arguments
+ component.__init__( # type: ignore[misc]
+ bands=deepcopy(self.bands, memo),
+ spectrum=deepcopy(self.spectrum, memo),
+ morph=deepcopy(self.morph, memo),
+ bbox=deepcopy(self.bbox, memo),
+ peak=deepcopy(self.peak, memo),
+ bg_rms=deepcopy(self.bg_rms, memo),
+ bg_thresh=self.bg_thresh,
+ floor=self.floor,
+ monotonicity=deepcopy(self.monotonicity, memo),
+ padding=self.padding,
+ is_symmetric=self.is_symmetric,
+ )
+ return component
+
+ def __copy__(self) -> FactorizedComponent:
+ """Create a copy of this component.
+
+ Returns
+ -------
+ component : FactorizedComponent
+ A new component that is a shallow copy of this one.
+ """
+ return FactorizedComponent(
+ bands=self.bands,
+ spectrum=self.spectrum,
+ morph=self.morph,
+ bbox=self.bbox,
+ peak=self.peak,
+ bg_rms=self.bg_rms,
+ bg_thresh=self.bg_thresh,
+ floor=self.floor,
+ monotonicity=self.monotonicity,
+ padding=self.padding,
+ is_symmetric=self.is_symmetric,
+ )
+
+
+class CubeComponent(Component):
+ """Dummy component for a component cube.
+
+ This is duck-typed to a `lsst.scarlet.lite.Component` in order to
+ generate a model from the component but it is currently not functional
+ in that it cannot be optimized, only persisted and loaded.
+
+ If scarlet lite ever implements a component as a data cube,
+ this class can be removed.
+ """
+
+ def __init__(self, model: Image, peak: tuple[int, int]):
+ """Initialization
+
+ Parameters
+ ----------
+ bands :
+ model :
+ The 3D (bands, y, x) model of the component.
+ peak :
+ The `(y, x)` peak of the component.
+ bbox :
+ The bounding box of the component.
+ """
+ super().__init__(model.bands, model.bbox)
+ self._model = model
+ self.peak = peak
+
+ def get_model(self) -> Image:
+ """Generate the model for the source
+
+ Returns
+ -------
+ model :
+ The model as a 3D `(band, y, x)` array.
+ """
+ return self._model
+
+ def resize(self, model_box: Box) -> bool:
+ """Resize the component if needed and return whether it was resized"""
+ Logger.warning("CubeComponent does not support resizing")
+ return False
+
+ def update(self, it: int, input_grad: np.ndarray) -> None:
+ """Implementation of unused abstract method"""
+ Logger.warning("CubeComponent does not support updates")
+
+ def parameterize(self, parameterization: Callable) -> None:
+ """Implementation of unused abstract method"""
+ Logger.warning("CubeComponent does not support parameterization")
+
+ def to_data(self) -> ScarletCubeComponentData:
+ """Convert the component to persistable ScarletComponentData
+
+ Returns
+ -------
+ component_data: ScarletComponentData
+ The data object containing the component information
+ """
+ from .io import ScarletCubeComponentData
+
+ return ScarletCubeComponentData(
+ origin=self.bbox.origin, # type: ignore
+ peak=self.peak, # type: ignore
+ model=self.get_model().data,
+ )
+
+ def __getitem__(self, indices: Any) -> CubeComponent:
+ """Get a sub-component corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices :
+ The indices to select.
+ Returns
+ -------
+ sub_component :
+ A new component that is a sub-component of this one.
+ """
+ band_indices = convert_indices(self.bands, indices)
+ if isinstance(band_indices, slice):
+ bands = self.bands[band_indices]
+ else:
+ bands = tuple(self.bands[i] for i in band_indices)
+
+ data = self.get_model()._data[band_indices,]
+ model = Image(data=data, bands=bands, yx0=cast(tuple[int, int], self.bbox.origin))
+ return CubeComponent(model=model, peak=self.peak)
+
+ def __copy__(self) -> CubeComponent:
+ """Create a copy of this component.
+
+ Returns
+ -------
+ component : ComponentCube
+ A new component that is a shallow copy of this one.
+ """
+ return CubeComponent(model=self._model, peak=self.peak)
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> CubeComponent:
+ """Create a deep copy of this component.
+
+ Parameters
+ ----------
+ memo: dict[int, Any]
+ The memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ component : ComponentCube
+ A new component that is a deep copy of this one.
+ """
+ if id(self) in memo:
+ return memo[id(self)]
+
+ # Create placeholder and add to memo FIRST
+ component = CubeComponent.__new__(CubeComponent)
+ memo[id(self)] = component
+
+ # Now safely initialize the placeholder with deepcopied arguments
+ component.__init__( # type: ignore[misc]
+ model=self._model.copy(),
+ peak=self.peak,
+ )
+ return component
+
def default_fista_parameterization(component: Component):
"""Initialize a factorized component to use FISTA PGM for optimization"""
diff --git a/python/lsst/scarlet/lite/image.py b/python/lsst/scarlet/lite/image.py
index c692dc0..461079b 100644
--- a/python/lsst/scarlet/lite/image.py
+++ b/python/lsst/scarlet/lite/image.py
@@ -22,13 +22,14 @@
from __future__ import annotations
import operator
+from copy import deepcopy
from typing import Any, Callable, Sequence, cast
import numpy as np
from numpy.typing import DTypeLike
from .bbox import Box
-from .utils import ScalarLike, ScalarTypes
+from .utils import ScalarLike, ScalarTypes, convert_indices
__all__ = ["Image", "MismatchedBoxError", "MismatchedBandsError"]
@@ -54,7 +55,7 @@ def get_dtypes(*data: np.ndarray | Image | ScalarLike) -> list[DTypeLike]:
result:
A list of datatypes.
"""
- dtypes: list[DTypeLike] = [None] * len(data)
+ dtypes: list[DTypeLike] = [float] * len(data)
for d, element in enumerate(data):
if hasattr(element, "dtype"):
dtypes[d] = cast(np.ndarray, element).dtype
@@ -496,24 +497,7 @@ def spectral_indices(self, bands: Sequence | slice) -> tuple[int, ...] | slice:
band_indices:
Tuple of indices for each band in this image.
"""
- if isinstance(bands, slice):
- # Convert a slice of band names into a slice of array indices
- # to select the appropriate slice.
- if bands.start is None:
- start = None
- else:
- start = self.bands.index(bands.start)
- if bands.stop is None:
- stop = None
- else:
- stop = self.bands.index(bands.stop) + 1
- return slice(start, stop, bands.step)
-
- if isinstance(bands, str):
- return (self.bands.index(bands),)
-
- band_indices = tuple(self.bands.index(band) for band in bands if band in self.bands)
- return band_indices
+ return convert_indices(self.bands, bands)
def matched_spectral_indices(
self,
@@ -545,7 +529,8 @@ def matched_spectral_indices(
err = "Attempted to insert a multi-band image into a monochromatic image"
raise ValueError(err)
- self_indices = cast(tuple[int, ...], self.spectral_indices(other.bands))
+ common_bands = tuple(set(self.bands).intersection(set(other.bands)))
+ self_indices = cast(tuple[int, ...], self.spectral_indices(common_bands))
matched_bands = tuple(self.bands[bidx] for bidx in self_indices)
other_indices = cast(tuple[int, ...], other.spectral_indices(matched_bands))
return other_indices, self_indices
@@ -682,6 +667,45 @@ def repeat(self, bands: tuple) -> Image:
yx0=self.yx0,
)
+ def __copy__(self) -> Image:
+ """Make a copy of this image.
+
+ Returns
+ -------
+ image: Image
+ The copy of this image.
+ """
+ return self.copy_with()
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> Image:
+ """Make a deep copy of this image.
+
+ Parameters
+ ----------
+ memo:
+ A dictionary of already copied objects to avoid infinite recursion.
+ Returns
+ -------
+ image: Image
+ The deep copy of this image.
+ """
+ # Check if already copied
+ if id(self) in memo:
+ return memo[id(self)]
+
+ # Create placeholder and add to memo FIRST
+ result = Image.__new__(Image)
+ memo[id(self)] = result
+
+ # Now safely initialize the placeholder with deepcopied arguments
+ result.__init__( # type: ignore[misc]
+ data=deepcopy(self.data, memo),
+ bands=deepcopy(self.bands, memo),
+ yx0=deepcopy(self.yx0, memo),
+ )
+
+ return result
+
def copy(self, order=None) -> Image:
"""Make a copy of this image.
diff --git a/python/lsst/scarlet/lite/io/blend.py b/python/lsst/scarlet/lite/io/blend.py
index 9a2db02..2f360e1 100644
--- a/python/lsst/scarlet/lite/io/blend.py
+++ b/python/lsst/scarlet/lite/io/blend.py
@@ -6,6 +6,7 @@
from typing import Any
import numpy as np
+from deprecated.sphinx import deprecated # type: ignore
from numpy.typing import DTypeLike
from ..bbox import Box
@@ -170,6 +171,11 @@ def to_blend(self, observation: Observation) -> Blend:
return Blend(sources=sources, observation=observation, metadata=self.metadata)
@staticmethod
+ @deprecated(
+ reason="ScarletBlendData.from_blend is deprecated. Use blend.to_data() instead.",
+ version="v30.0",
+ category=FutureWarning,
+ )
def from_blend(blend: Blend) -> ScarletBlendData:
"""Deprecated: Convert a scarlet lite blend into a storage data model.
@@ -182,7 +188,6 @@ def from_blend(blend: Blend) -> ScarletBlendData:
result :
The storage data model representing the blend.
"""
- logger.warning("ScarletBlendData.from_blend is deprecated. Use blend.to_data() instead.")
return blend.to_data()
diff --git a/python/lsst/scarlet/lite/io/blend_base.py b/python/lsst/scarlet/lite/io/blend_base.py
index 82b4dc5..7d71ce6 100644
--- a/python/lsst/scarlet/lite/io/blend_base.py
+++ b/python/lsst/scarlet/lite/io/blend_base.py
@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Any, ClassVar
+import numpy as np
from numpy.typing import DTypeLike
from ..bbox import Box
@@ -54,7 +55,7 @@ def as_dict(self) -> dict:
"""
@staticmethod
- def from_dict(data: dict, dtype: DTypeLike | None = None) -> ScarletBlendBaseData:
+ def from_dict(data: dict, dtype: DTypeLike = np.float32) -> ScarletBlendBaseData:
"""Reconstruct `ScarletBlendBaseData` from JSON compatible dict.
Parameters
diff --git a/python/lsst/scarlet/lite/io/component.py b/python/lsst/scarlet/lite/io/component.py
index f67541e..223e64d 100644
--- a/python/lsst/scarlet/lite/io/component.py
+++ b/python/lsst/scarlet/lite/io/component.py
@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import ClassVar
+import numpy as np
from numpy.typing import DTypeLike
from ..component import Component
@@ -62,7 +63,7 @@ def as_dict(self) -> dict:
"""
@staticmethod
- def from_dict(data: dict, dtype: DTypeLike | None = None) -> ScarletComponentBaseData:
+ def from_dict(data: dict, dtype: DTypeLike = np.float32) -> ScarletComponentBaseData:
"""Reconstruct `ScarletComponentBaseData` from JSON compatible
dict.
diff --git a/python/lsst/scarlet/lite/io/cube_component.py b/python/lsst/scarlet/lite/io/cube_component.py
index 9d2acc9..3cd0f47 100644
--- a/python/lsst/scarlet/lite/io/cube_component.py
+++ b/python/lsst/scarlet/lite/io/cube_component.py
@@ -1,13 +1,14 @@
from __future__ import annotations
+import logging
from dataclasses import dataclass
-from typing import Callable
import numpy as np
+from deprecated.sphinx import deprecated # type: ignore
from numpy.typing import DTypeLike
from ..bbox import Box
-from ..component import Component
+from ..component import CubeComponent
from ..image import Image
from ..observation import Observation
from .component import ScarletComponentBaseData
@@ -19,68 +20,20 @@
COMPONENT_TYPE = "cube"
MigrationRegistry.set_current(COMPONENT_TYPE, CURRENT_SCHEMA)
+logger = logging.getLogger(__name__)
-class ComponentCube(Component):
- """Dummy component for a component cube.
- This is duck-typed to a `lsst.scarlet.lite.Component` in order to
- generate a model from the component but it is currently not functional
- in that it cannot be optimized, only persisted and loaded.
-
- If scarlet lite ever implements a component as a data cube,
- this class can be removed.
- """
+@deprecated(
+ reason="ComponentCube is deprecated and will be removed after scarlet_lite v30.0. "
+ "Please use CubeComponent instead.",
+ version="scarlet_lite v30.0",
+ category=FutureWarning,
+)
+class ComponentCube(CubeComponent):
+ """Deprecated, use CubeComponent instead."""
def __init__(self, model: Image, peak: tuple[int, int]):
- """Initialization
-
- Parameters
- ----------
- bands :
- model :
- The 3D (bands, y, x) model of the component.
- peak :
- The `(y, x)` peak of the component.
- bbox :
- The bounding box of the component.
- """
- super().__init__(model.bands, model.bbox)
- self._model = model
- self.peak = peak
-
- def get_model(self) -> Image:
- """Generate the model for the source
-
- Returns
- -------
- model :
- The model as a 3D `(band, y, x)` array.
- """
- return self._model
-
- def resize(self, model_box: Box) -> bool:
- """Test whether or not the component needs to be resized"""
- return False
-
- def update(self, it: int, input_grad: np.ndarray) -> None:
- """Implementation of unused abstract method"""
-
- def parameterize(self, parameterization: Callable) -> None:
- """Implementation of unused abstract method"""
-
- def to_data(self) -> ScarletCubeComponentData:
- """Convert the component to persistable ScarletComponentData
-
- Returns
- -------
- component_data: ScarletComponentData
- The data object containing the component information
- """
- return ScarletCubeComponentData(
- origin=self.bbox.origin, # type: ignore
- peak=self.peak, # type: ignore
- model=self.get_model().data,
- )
+ super().__init__(model=model, peak=peak)
@dataclass(kw_only=True)
@@ -110,7 +63,7 @@ class ScarletCubeComponentData(ScarletComponentBaseData):
def shape(self):
return self.model.shape[-2:]
- def to_component(self, observation: Observation) -> ComponentCube:
+ def to_component(self, observation: Observation) -> CubeComponent:
"""Convert the storage data model into a scarlet Component
Parameters
@@ -130,7 +83,7 @@ def to_component(self, observation: Observation) -> ComponentCube:
else:
peak = (int(np.round(self.peak[0])), int(np.round(self.peak[0])))
assert peak is not None
- component = ComponentCube(
+ component = CubeComponent(
model=Image(model, yx0=bbox.origin, bands=observation.bands), # type: ignore
peak=peak,
)
diff --git a/python/lsst/scarlet/lite/io/source.py b/python/lsst/scarlet/lite/io/source.py
index 5078c86..596b9a9 100644
--- a/python/lsst/scarlet/lite/io/source.py
+++ b/python/lsst/scarlet/lite/io/source.py
@@ -5,6 +5,7 @@
from typing import Any
import numpy as np
+from deprecated.sphinx import deprecated # type: ignore
from numpy.typing import DTypeLike
from ..component import Component
@@ -104,6 +105,11 @@ def to_source(self, observation: Observation) -> Source:
return Source(components=components, metadata=self.metadata)
@staticmethod
+ @deprecated(
+ reason="from_source is deprecated and will be removed in a future release.",
+ version="v30.0",
+ category=FutureWarning,
+ )
def from_source(source: Source) -> ScarletSourceData:
"""Deprecated: Create a `ScarletSourceData` from a scarlet `Source`
@@ -117,7 +123,6 @@ def from_source(source: Source) -> ScarletSourceData:
result:
The `ScarletSourceData` representation of the source.
"""
- logger.warning("from_source is deprecated and will be removed in a future release.")
return source.to_data()
diff --git a/python/lsst/scarlet/lite/models/free_form.py b/python/lsst/scarlet/lite/models/free_form.py
index 278e105..a1b1202 100644
--- a/python/lsst/scarlet/lite/models/free_form.py
+++ b/python/lsst/scarlet/lite/models/free_form.py
@@ -18,20 +18,25 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+from __future__ import annotations
__all__ = ["FactorizedFreeFormComponent"]
-from typing import Callable, cast
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Callable, cast
import numpy as np
-from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore
from ..bbox import Box
from ..component import Component, FactorizedComponent
from ..detect import footprints_to_image
+from ..detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore
from ..image import Image
from ..parameters import Parameter, parameter
+if TYPE_CHECKING:
+ from ..io.component import ScarletComponentBaseData
+
class FactorizedFreeFormComponent(FactorizedComponent):
"""Implements a free-form component
@@ -238,3 +243,90 @@ def __str__(self):
def __repr__(self):
return self.__str__()
+
+ def to_data(self) -> ScarletComponentBaseData:
+ raise NotImplementedError("Serialization not implemented for FreeFormComponent")
+
+ def __getitem__(self, indices: Any) -> FreeFormComponent:
+ """Get a sub-component corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices: Any
+ The indices to use to slice the component model.
+
+ Returns
+ -------
+ component: FreeFormComponent
+ A new component that is a sub-component of this one.
+
+ Raises
+ ------
+ IndexError :
+ If the index includes a ``Box`` or spatial indices.
+ """
+ if indices in self.bands:
+ bands = (indices,)
+ else:
+ bands = tuple(indices)
+
+ return FreeFormComponent(
+ bands=bands,
+ model=self.model[indices],
+ model_bbox=self.bbox,
+ bg_thresh=self.bg_thresh,
+ bg_rms=self.bg_rms,
+ floor=self.floor,
+ peaks=self.peaks,
+ min_area=self.min_area,
+ )
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> FreeFormComponent:
+ """Create a deep copy of this component.
+
+ Parameters
+ ----------
+ memo: dict[int, Any]
+ A dictionary to keep track of already copied objects.
+
+ Returns
+ -------
+ component : FreeFormComponent
+ A new component that is a deep copy of this one.
+ """
+ if id(self) in memo:
+ return memo[id(self)]
+
+ component = FreeFormComponent.__new__(FreeFormComponent)
+ memo[id(self)] = component
+
+ component.__init__( # type: ignore[misc]
+ bands=deepcopy(self.bands),
+ model=deepcopy(self.model),
+ model_bbox=deepcopy(self.bbox),
+ bg_thresh=self.bg_thresh,
+ bg_rms=deepcopy(self.bg_rms),
+ floor=self.floor,
+ peaks=deepcopy(self.peaks),
+ min_area=self.min_area,
+ )
+ return component
+
+ def __copy__(self) -> FreeFormComponent:
+ """Create a copy of this component.
+
+ Returns
+ -------
+ component : FreeFormComponent
+ A new component that is a copy of this one.
+ """
+ return FreeFormComponent(
+ bands=self.bands,
+ model=self.model,
+ model_bbox=self.bbox,
+ bg_thresh=self.bg_thresh,
+ bg_rms=self.bg_rms,
+ floor=self.floor,
+ peaks=self.peaks,
+ min_area=self.min_area,
+ )
diff --git a/python/lsst/scarlet/lite/models/parametric.py b/python/lsst/scarlet/lite/models/parametric.py
index ed05f1e..246bacc 100644
--- a/python/lsst/scarlet/lite/models/parametric.py
+++ b/python/lsst/scarlet/lite/models/parametric.py
@@ -37,7 +37,8 @@
"EllipticalParametricComponent",
]
-from typing import TYPE_CHECKING, Callable, Sequence, cast
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
import numpy as np
from scipy.special import erf
@@ -47,6 +48,7 @@
from ..component import Component
from ..image import Image
from ..parameters import Parameter, parameter
+from ..utils import convert_indices
if TYPE_CHECKING:
from ..io import ScarletComponentBaseData
@@ -842,6 +844,105 @@ def parameterize(self, parameterization: Callable) -> None:
def to_data(self) -> ScarletComponentBaseData:
raise NotImplementedError("Saving elliptical parametric components is not yet implemented")
+ def __getitem__(self, indices: Any) -> ParametricComponent:
+ """Get a sub-component corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices: Any
+ The indices to use to slice the component model.
+
+ Returns
+ -------
+ component: ParametricComponent
+ A new component that is a sub-component of this one.
+
+ Raises
+ ------
+ IndexError :
+ If the index includes a ``Box`` or spatial indices.
+ """
+ # Update the bands
+ if indices in self._bands:
+ # Single band case
+ bands = (indices,)
+ else:
+ # Multiple bands case
+ bands = tuple(indices)
+
+ # Convert the band indices into numerical indices
+ band_indices = convert_indices(self.bands, indices)
+
+ # Slice the spectrum
+ spectrum = self._spectrum.x[band_indices]
+
+ return ParametricComponent(
+ bands=bands,
+ bbox=self.bbox,
+ spectrum=spectrum,
+ morph_params=self.radial_params,
+ morph_func=self._func,
+ morph_grad=self._morph_grad,
+ morph_prox=self._morph_prox,
+ morph_step=self._morph_step,
+ prox_spectrum=self._prox_spectrum,
+ floor=self.floor,
+ )
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> ParametricComponent:
+ """Create a deep copy of this component
+
+ Parameters
+ ----------
+ memo:
+ The memoization dictionary used by `copy.deepcopy`.
+ Returns
+ -------
+ component : ParametricComponent
+ A new component that is a deep copy of this one.
+ """
+ if id(self) in memo:
+ return memo[id(self)]
+
+ component = ParametricComponent.__new__(ParametricComponent)
+ memo[id(self)] = component
+
+ component.__init__( # type: ignore[misc]
+ bands=deepcopy(self.bands),
+ bbox=deepcopy(self.bbox),
+ spectrum=deepcopy(self.spectrum),
+ morph_params=deepcopy(self.radial_params),
+ morph_func=self._func,
+ morph_grad=self._morph_grad,
+ morph_prox=self._morph_prox,
+ morph_step=self._morph_step,
+ prox_spectrum=self._prox_spectrum,
+ floor=self.floor,
+ )
+
+ return component
+
+ def __copy__(self) -> ParametricComponent:
+ """Create a copy of this component
+
+ Returns
+ -------
+ component : ParametricComponent
+ A new component that is a shallow copy of this one.
+ """
+ return ParametricComponent(
+ bands=self.bands,
+ bbox=self.bbox,
+ spectrum=self.spectrum,
+ morph_params=self.radial_params,
+ morph_func=self._func,
+ morph_grad=self._morph_grad,
+ morph_prox=self._morph_prox,
+ morph_step=self._morph_step,
+ prox_spectrum=self._prox_spectrum,
+ floor=self.floor,
+ )
+
class EllipticalParametricComponent(ParametricComponent):
"""A radial density/surface brightness profile with elliptical symmetry
diff --git a/python/lsst/scarlet/lite/observation.py b/python/lsst/scarlet/lite/observation.py
index 44fe404..56d13b9 100644
--- a/python/lsst/scarlet/lite/observation.py
+++ b/python/lsst/scarlet/lite/observation.py
@@ -23,6 +23,7 @@
__all__ = ["Observation", "convolve"]
+from copy import deepcopy
from typing import Any, cast
import numpy as np
@@ -63,9 +64,9 @@ def get_filter_coords(filter_values: np.ndarray, center: tuple[int, int] | None
calculate `coords` on your own."""
raise ValueError(msg)
center = tuple([filter_values.shape[0] // 2, filter_values.shape[1] // 2]) # type: ignore
- x = np.arange(filter_values.shape[1])
- y = np.arange(filter_values.shape[0])
- x, y = np.meshgrid(x, y)
+ _x = np.arange(filter_values.shape[1])
+ _y = np.arange(filter_values.shape[0])
+ x, y = np.meshgrid(_x, _y)
x -= center[1]
y -= center[0]
coords = np.dstack([y, x])
@@ -366,6 +367,29 @@ def __getitem__(self, indices: Any) -> Observation:
new_variance = self.variance[indices]
new_weights = self.weights[indices]
+ # If the indices is a single band, make sure to keep the band axis
+ if new_image.ndim == 2:
+ if indices in self.bands:
+ new_bands = (indices,)
+ else:
+ # The indices contain spatial and band indices
+ new_bands = (indices[0],)
+ new_image = Image(
+ new_image.data[None, :, :],
+ yx0=new_image.yx0,
+ bands=new_bands,
+ )
+ new_variance = Image(
+ new_variance.data[None, :, :],
+ yx0=new_variance.yx0,
+ bands=new_bands,
+ )
+ new_weights = Image(
+ new_weights.data[None, :, :],
+ yx0=new_weights.yx0,
+ bands=new_bands,
+ )
+
# Extract the appropriate bands from the PSF
bands = self.images.bands
new_bands = new_image.bands
@@ -385,56 +409,67 @@ def __getitem__(self, indices: Any) -> Observation:
model_psf=self.model_psf,
noise_rms=noise_rms,
bbox=new_image.bbox,
- bands=self.bands,
+ bands=new_bands,
padding=self.padding,
convolution_mode=self.mode,
)
- def __copy__(self, deep: bool = False) -> Observation:
+ def __copy__(self) -> Observation:
"""Create a copy of the observation
- Parameters
- ----------
- deep:
- Whether to perform a deep copy or not.
-
Returns
-------
result:
The copy of the observation.
"""
- if deep:
- if self.model_psf is None:
- model_psf = None
- else:
- model_psf = self.model_psf.copy()
+ return Observation(
+ images=self.images,
+ variance=self.variance,
+ weights=self.weights,
+ psfs=self.psfs,
+ model_psf=self.model_psf,
+ noise_rms=self.noise_rms,
+ bands=self.bands,
+ padding=self.padding,
+ convolution_mode=self.mode,
+ )
- if self.noise_rms is None:
- noise_rms = None
- else:
- noise_rms = self.noise_rms.copy()
+ def __deepcopy__(self, memo: dict[int, Any]) -> Observation:
+ """Create a deep copy of the observation
- if self.bands is None:
- bands = None
- else:
- bands = tuple([b for b in self.bands])
- else:
- model_psf = self.model_psf
- noise_rms = self.noise_rms
- bands = self.bands
+ Parameters
+ ----------
+ memo: dict[int, Any]
+ The memoization dictionary used by `copy.deepcopy`.
- return Observation(
- images=self.images.copy(),
- variance=self.variance.copy(),
- weights=self.weights.copy(),
- psfs=self.psfs.copy(),
- model_psf=model_psf,
- noise_rms=noise_rms,
- bands=bands,
- padding=self.padding,
+ Returns
+ -------
+ result:
+ The deep copy of the observation.
+ """
+ # Check if already copied
+ if id(self) in memo:
+ return memo[id(self)]
+
+ # Create placeholder and add to memo FIRST
+ result = Observation.__new__(Observation)
+ memo[id(self)] = result
+
+ # Now safely initialize the placeholder with deepcopied arguments
+ result.__init__( # type: ignore[misc]
+ images=deepcopy(self.images, memo),
+ variance=deepcopy(self.variance, memo),
+ weights=deepcopy(self.weights, memo),
+ psfs=deepcopy(self.psfs, memo),
+ model_psf=deepcopy(self.model_psf, memo),
+ noise_rms=deepcopy(self.noise_rms, memo),
+ bands=deepcopy(self.bands, memo),
+ padding=deepcopy(self.padding, memo),
convolution_mode=self.mode,
)
+ return result
+
def copy(self, deep: bool = False) -> Observation:
"""Create a copy of the observation
@@ -448,7 +483,9 @@ def copy(self, deep: bool = False) -> Observation:
result:
The copy of the observation.
"""
- return self.__copy__(deep)
+ if deep:
+ return self.__deepcopy__({})
+ return self.__copy__()
@property
def shape(self) -> tuple[int, int, int]:
diff --git a/python/lsst/scarlet/lite/operators.py b/python/lsst/scarlet/lite/operators.py
index 1fc9f74..f1d3c6b 100644
--- a/python/lsst/scarlet/lite/operators.py
+++ b/python/lsst/scarlet/lite/operators.py
@@ -1,4 +1,6 @@
-from typing import Callable, Sequence, cast
+from __future__ import annotations
+
+from typing import Any, Callable, Sequence, cast
import numpy as np
import numpy.typing as npt
@@ -251,6 +253,32 @@ def __call__(self, image: np.ndarray, center: tuple[int, int]) -> np.ndarray:
image[:] = result[1:-1, 1:-1]
return image
+ def __copy__(self) -> Monotonicity:
+ """Create a shallow copy of the operator
+
+ Returns
+ -------
+ result:
+ A copy of the operator.
+ """
+ new = Monotonicity(self.shape, self.dtype, self.auto_update, self.fit_radius)
+ return new
+
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Monotonicity:
+ """Create a deep copy of the operator
+
+ Parameters
+ ----------
+ memo:
+ The memoization dictionary for deep copies.
+
+ Returns
+ -------
+ result:
+ A copy of the operator.
+ """
+ return self.__copy__()
+
def get_peak(image: np.ndarray, center: tuple[int, int], radius: int = 1) -> tuple[int, int]:
"""Search around a location for the maximum flux
diff --git a/python/lsst/scarlet/lite/parameters.py b/python/lsst/scarlet/lite/parameters.py
index 5790c4f..03a16ab 100644
--- a/python/lsst/scarlet/lite/parameters.py
+++ b/python/lsst/scarlet/lite/parameters.py
@@ -32,7 +32,8 @@
"DEFAULT_ADAPROX_FACTOR",
]
-from typing import Callable, Sequence, cast
+from copy import deepcopy
+from typing import Any, Callable, Sequence, cast
import numpy as np
import numpy.typing as npt
@@ -120,11 +121,50 @@ def dtype(self) -> npt.DTypeLike:
"""The numpy dtype of the array that is being fit."""
return self.x.dtype
- def copy(self) -> Parameter:
- """Copy this parameter, including all of the helper arrays."""
+ def __copy__(self) -> Parameter:
+ """Create a shallow copy of this parameter.
+
+ Returns
+ -------
+ parameter:
+ A shallow copy of this parameter.
+ """
helpers = {k: v.copy() for k, v in self.helpers.items()}
return Parameter(self.x.copy(), helpers, 0)
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Parameter:
+ """Create a deep copy of this parameter.
+
+ Parameters
+ ----------
+ memo:
+ A memoization dictionary used by `copy.deepcopy`.
+ Returns
+ -------
+ parameter:
+ A deep copy of this parameter.
+ """
+ helpers = {k: deepcopy(v, memo) for k, v in self.helpers.items()}
+ return Parameter(deepcopy(self.x, memo), helpers, 0)
+
+ def copy(self, deep: bool = False) -> Parameter:
+ """Copy this parameter, including all of the helper arrays.
+
+ Parameters
+ ----------
+ deep:
+ If `True`, a deep copy is made.
+ If `False`, a shallow copy is made.
+
+ Returns
+ -------
+ parameter:
+ A copy of this parameter.
+ """
+ if deep:
+ return self.__deepcopy__({})
+ return self.__copy__()
+
def update(self, it: int, input_grad: np.ndarray, *args):
"""Update the parameter in one iteration.
@@ -197,7 +237,7 @@ def __init__(
z0: np.ndarray | None = None,
):
if z0 is None:
- z0 = x
+ z0 = x.copy()
super().__init__(
x,
@@ -231,6 +271,44 @@ def update(self, it: int, input_grad: np.ndarray, *args):
_x[:] = x
self.t = t
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FistaParameter:
+ """Create a deep copy of this parameter.
+
+ Parameters
+ ----------
+ memo:
+ A memoization dictionary used by `copy.deepcopy`.
+ Returns
+ -------
+ parameter:
+ A deep copy of this parameter.
+ """
+ return FistaParameter(
+ deepcopy(self.x, memo),
+ self.step,
+ self.grad,
+ self.prox,
+ self.t,
+ deepcopy(self.helpers["z"], memo),
+ )
+
+ def __copy__(self) -> FistaParameter:
+ """Create a shallow copy of this parameter.
+
+ Returns
+ -------
+ parameter:
+ A shallow copy of this parameter.
+ """
+ return FistaParameter(
+ self.x.copy(),
+ self.step,
+ self.grad,
+ self.prox,
+ self.t,
+ self.helpers["z"].copy(),
+ )
+
# The following code block contains different update methods for
# various implementations of ADAM.
@@ -375,7 +453,7 @@ def __init__(
step: Callable | float,
grad: Callable | None = None,
prox: Callable | None = None,
- b1: float = 0.9,
+ b1: float | SingleItemArray = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
p: float = 0.25,
@@ -418,6 +496,7 @@ def __init__(
self.eps = eps
self.p = p
+ self.scheme = scheme
self.phi_psi = phi_psi[scheme]
self.e_rel = prox_e_rel
@@ -453,6 +532,58 @@ def update(self, it: int, input_grad: np.ndarray, *args):
self.x = cast(Callable, self.prox)(_x)
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> AdaproxParameter:
+ """Create a deep copy of this parameter.
+
+ Parameters
+ ----------
+ memo:
+ A memoization dictionary used by `copy.deepcopy`.
+ Returns
+ -------
+ parameter:
+ A deep copy of this parameter.
+ """
+ return AdaproxParameter(
+ deepcopy(self.x, memo),
+ self.step,
+ self.grad,
+ self.prox,
+ self.b1,
+ self.b2,
+ self.eps,
+ self.p,
+ deepcopy(self.helpers["m"], memo),
+ deepcopy(self.helpers["v"], memo),
+ deepcopy(self.helpers["vhat"], memo),
+ scheme=self.scheme,
+ prox_e_rel=self.e_rel,
+ )
+
+ def __copy__(self) -> AdaproxParameter:
+ """Create a shallow copy of this parameter.
+
+ Returns
+ -------
+ parameter:
+ A shallow copy of this parameter.
+ """
+ return AdaproxParameter(
+ self.x,
+ self.step,
+ self.grad,
+ self.prox,
+ self.b1,
+ self.b2,
+ self.eps,
+ self.p,
+ self.helpers["m"],
+ self.helpers["v"],
+ self.helpers["vhat"],
+ scheme=self.scheme,
+ prox_e_rel=self.e_rel,
+ )
+
class FixedParameter(Parameter):
"""A parameter that is not updated"""
@@ -463,6 +594,31 @@ def __init__(self, x: np.ndarray):
def update(self, it: int, input_grad: np.ndarray, *args):
pass
+ def __copy__(self) -> FixedParameter:
+ """Create a shallow copy of this parameter.
+
+ Returns
+ -------
+ parameter:
+ A shallow copy of this parameter.
+ """
+ return FixedParameter(self.x)
+
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FixedParameter:
+ """Create a deep copy of this parameter.
+
+ Parameters
+ ----------
+ memo:
+ A memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ parameter:
+ A deep copy of this parameter.
+ """
+ return FixedParameter(deepcopy(self.x, memo))
+
def relative_step(
x: np.ndarray,
diff --git a/python/lsst/scarlet/lite/source.py b/python/lsst/scarlet/lite/source.py
index c17afcc..5acab55 100644
--- a/python/lsst/scarlet/lite/source.py
+++ b/python/lsst/scarlet/lite/source.py
@@ -24,7 +24,8 @@
__all__ = ["Source"]
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Callable, Self
from .bbox import Box
from .component import Component
@@ -42,6 +43,7 @@ class SourceBase(ABC):
"""
metadata: dict[str, Any] | None = None
+ components: list[Component]
@abstractmethod
def to_data(self) -> ScarletSourceBaseData:
@@ -53,6 +55,69 @@ def to_data(self) -> ScarletSourceBaseData:
The `ScarletSourceData` representation of this source.
"""
+ @abstractmethod
+ def __getitem__(self, indices: Any) -> Self:
+ """Get a sub-source corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices: Any
+ The indices to use to slice the source model.
+
+ Returns
+ -------
+ source: SourceBase
+ A new source that is a sub-source of this one.
+
+ Raises
+ ------
+ IndexError :
+ If the index includes a ``Box`` or spatial indices.
+ """
+
+ @abstractmethod
+ def __deepcopy__(self, memo: dict[int, Any]) -> Self:
+ """Create a deep copy of this source.
+
+ Parameters
+ ----------
+ memo : dict[int, Any]
+ A memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ source : SourceBase
+ A new source that is a deep copy of this one.
+ """
+
+ @abstractmethod
+ def __copy__(self) -> Self:
+ """Create a copy of this source.
+
+ Returns
+ -------
+ source : SourceBase
+ A new source that is a copy of this one.
+ """
+
+ def copy(self, deep: bool = False) -> Self:
+ """Create a copy of this source.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ If `True`, a deep copy is made. If `False`, a shallow copy is made.
+ Default is `False`.
+
+ Returns
+ -------
+ source : Self
+ A new source that is a copy of this one.
+ """
+ if deep:
+ return self.__deepcopy__({})
+ return self.__copy__()
+
class Source(SourceBase):
"""A container for components associated with the same astrophysical object
@@ -66,9 +131,14 @@ class Source(SourceBase):
The components contained in the source.
"""
- def __init__(self, components: list[Component], metadata: dict | None = None):
+ def __init__(
+ self,
+ components: list[Component],
+ metadata: dict | None = None,
+ flux_weighted_image: Image | None = None,
+ ):
self.components = components
- self.flux_weighted_image: Image | None = None
+ self.flux_weighted_image = flux_weighted_image
self.metadata = metadata
@property
@@ -182,3 +252,74 @@ def __str__(self):
def __repr__(self):
return f"Source(components={repr(self.components)})>"
+
+ def __getitem__(self, indices: Any) -> Source:
+ """Get a sub-source corresponding to the given indices.
+
+ Parameters
+ ----------
+ indices: Any
+ The indices to use to slice the source model. Can be:
+ - A single band
+ - A slice with start/stop bands
+ - A sequence of bands
+
+ Returns
+ -------
+ source: Source
+ A new source that is a sub-source of this one.
+
+ Raises
+ ------
+ IndexError :
+ If the index includes a ``Box`` or spatial indices.
+ """
+ flux = None if self.flux_weighted_image is None else self.flux_weighted_image[indices]
+ return Source(
+ components=[c[indices] for c in self.components],
+ metadata=self.metadata,
+ flux_weighted_image=flux,
+ )
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> Source:
+ """Create a deep copy of this source.
+
+ Parameters
+ ----------
+ memo : dict[int, Any]
+ A memoization dictionary used by `copy.deepcopy`.
+
+ Returns
+ -------
+ source : SourceBase
+ A new source that is a deep copy of this one.
+ """
+ # Check if already copied
+ if id(self) in memo:
+ return memo[id(self)]
+
+ # Create placeholder and add to memo FIRST
+ source = Source.__new__(Source)
+ memo[id(self)] = source
+
+ source.__init__( # type: ignore[misc]
+ components=deepcopy(self.components, memo),
+ metadata=deepcopy(self.metadata, memo),
+ flux_weighted_image=deepcopy(self.flux_weighted_image, memo),
+ )
+ return source
+
+ def __copy__(self) -> Source:
+ """Create a copy of this source.
+
+ Returns
+ -------
+ source : SourceBase
+ A new source that is a copy of this one.
+ """
+ source = Source(
+ components=self.components,
+ metadata=self.metadata,
+ flux_weighted_image=self.flux_weighted_image,
+ )
+ return source
diff --git a/python/lsst/scarlet/lite/utils.py b/python/lsst/scarlet/lite/utils.py
index a85fb56..6340323 100644
--- a/python/lsst/scarlet/lite/utils.py
+++ b/python/lsst/scarlet/lite/utils.py
@@ -19,7 +19,10 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+from __future__ import annotations
+
import sys
+from typing import Any, Sequence
import numpy as np
import numpy.typing as npt
@@ -158,6 +161,76 @@ def is_attribute_safe_to_transfer(name, value):
return True
+def convert_indices(sequence: Sequence, indices: Any, inclusive: bool = True) -> tuple[int, ...] | slice:
+ """Get either a tuple of indices or a slice object from the given sequence.
+
+ Parameters
+ ----------
+ sequence : Sequence
+ The sequence to get the indices from. This sequence should have
+ unique hashable elements.
+
+ indices : Any
+ The indices or slice to use. Can be:
+ - A single element from sequence
+ - A slice with start/stop elements from sequence
+ - A sequence of elements from sequence
+
+ inclusive : bool, optional
+ If True, the stop element of a slice is inclusive.
+
+ Returns
+ -------
+ tuple[int, ...] | slice
+ A tuple of indices or a slice object.
+
+ Raises
+ ------
+ TypeError :
+ If `sequence` does not support `index` and `in` operations.
+ IndexError :
+ If a single element is not found in `sequence`.
+ """
+ # Validate that sequence has the required methods
+ if not hasattr(sequence, "index") or not hasattr(sequence, "__contains__"):
+ raise TypeError(f"'sequence' must support 'index' and 'in' operations, got {type(sequence)}")
+
+ # Handle slice objects
+ if isinstance(indices, slice):
+ # Convert a slice of objects into a slice of array indices
+ try:
+ start = None if indices.start is None else sequence.index(indices.start)
+ except ValueError as e:
+ raise IndexError(f"Element {indices.start} not found in sequence {sequence}.") from e
+ try:
+ stop = None if indices.stop is None else sequence.index(indices.stop) + (1 if inclusive else 0)
+ except ValueError as e:
+ raise IndexError(f"Element {indices.stop} not found in sequence {sequence}.") from e
+ return slice(start, stop, indices.step)
+
+ # Try to handle as a single element first
+ if indices in sequence:
+ return (sequence.index(indices),)
+
+ # Validate that indices is iterable
+ if not hasattr(indices, "__iter__"):
+ raise IndexError(f"Element {indices} not found in sequence {sequence}.")
+
+ # Handle sequence of indices
+ index_map = {value: idx for idx, value in enumerate(sequence)}
+ new_indices = []
+ for i in indices:
+ try:
+ if i not in index_map:
+ raise IndexError(f"Element {i} not found in sequence {sequence}.")
+ except TypeError as e:
+ # If the
+ raise IndexError(f"Element {i} not found in sequence {sequence}.") from e
+ new_indices.append(index_map[i])
+
+ return tuple(new_indices)
+
+
def continue_class(cls):
"""Re-open the decorated class, adding any new definitions into the
original.
diff --git a/tests/test_bbox.py b/tests/test_bbox.py
index cc4cbb8..85856bf 100644
--- a/tests/test_bbox.py
+++ b/tests/test_bbox.py
@@ -19,6 +19,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+from copy import deepcopy
+
import numpy as np
from lsst.scarlet.lite import Box
from utils import ScarletTestCase
@@ -213,3 +215,15 @@ def test_slicing(self):
self.assertBoxEqual(bbox[:3], Box((1, 2, 3), (2, 4, 6)))
# check tuple index
self.assertBoxEqual(bbox[(3, 1)], Box((4, 2), (8, 4)))
+
+ def test_shallow_copy(self):
+ bbox = Box((3, 4, 5), (10, 20, 30))
+ bbox_copy = bbox.copy()
+ self.assertBoxEqual(bbox, bbox_copy)
+ self.assertIsNot(bbox, bbox_copy)
+
+ def test_deepcopy(self):
+ bbox = Box((3, 4, 5), (10, 20, 30))
+ bbox_deepcopy = deepcopy(bbox)
+ self.assertBoxEqual(bbox, bbox_deepcopy)
+ self.assertIsNot(bbox, bbox_deepcopy)
diff --git a/tests/test_blend.py b/tests/test_blend.py
index 9604993..e27295e 100644
--- a/tests/test_blend.py
+++ b/tests/test_blend.py
@@ -21,45 +21,19 @@
from __future__ import annotations
-from typing import Callable, cast
+from typing import cast
import numpy as np
from lsst.scarlet.lite import Blend, Box, Image, Observation, Source
-from lsst.scarlet.lite.component import Component, FactorizedComponent, default_adaprox_parameterization
+from lsst.scarlet.lite.component import CubeComponent, FactorizedComponent, default_adaprox_parameterization
from lsst.scarlet.lite.initialization import FactorizedInitialization
from lsst.scarlet.lite.operators import Monotonicity
-from lsst.scarlet.lite.parameters import Parameter
from lsst.scarlet.lite.utils import integrated_circular_gaussian
from numpy.testing import assert_almost_equal, assert_raises
from scipy.signal import convolve as scipy_convolve
from utils import ObservationData, ScarletTestCase
-class DummyCubeComponent(Component):
- def __init__(self, model: Image):
- super().__init__(model.bands, model.bbox)
- self._model = Parameter(model.data, {}, 0)
-
- @property
- def data(self) -> np.ndarray:
- return self._model.x
-
- def resize(self, model_box: Box) -> bool:
- pass
-
- def update(self, it: int, input_grad: np.ndarray):
- pass
-
- def get_model(self) -> Image:
- return Image(self.data, bands=self.bands, yx0=self.bbox.origin)
-
- def parameterize(self, parameterization: Callable) -> None:
- pass
-
- def to_data(self) -> DummyCubeComponent:
- pass
-
-
class TestBlend(ScarletTestCase):
def setUp(self):
bands = ("g", "r", "i")
@@ -230,7 +204,7 @@ def test_non_factorized(self):
# Remove the disk component from the first source
blend.sources[0].components = blend.sources[0].components[:1]
# Create a new source for the disk with a non-factorized component
- component = DummyCubeComponent(Image(model, bands=self.blend.observation.bands, yx0=yx0))
+ component = CubeComponent(Image(model, bands=self.blend.observation.bands, yx0=yx0), (0, 0))
blend.sources.append(Source([component]))
blend.fit_spectra()
@@ -254,7 +228,7 @@ def test_clipping(self):
# Add an empty source
zero_model = Image.from_box(Box((5, 5), (30, 0)), bands=blend.observation.bands)
- component = DummyCubeComponent(zero_model)
+ component = CubeComponent(zero_model, (0, 0))
blend.sources.append(Source([component]))
blend.fit_spectra(clip=True)
@@ -262,3 +236,107 @@ def test_clipping(self):
self.assertEqual(len(blend.components), 5)
self.assertEqual(len(blend.sources), 5)
self.assertImageAlmostEqual(blend.get_model(), self.data.images)
+
+ def test_shallow_copy(self):
+ blend = self.blend
+ blend.metadata = {"test": "value"}
+ blend_copy = blend.copy()
+
+ self.assertIsNot(blend_copy, blend)
+ self.assertEqual(len(blend_copy.sources), len(blend.sources))
+ for source_copy, source in zip(blend_copy.sources, blend.sources):
+ self.assertSourceEqual(source_copy, source)
+
+ self.assertObservationEqual(blend_copy.observation, blend.observation)
+
+ self.assertDictEqual(blend_copy.metadata, blend.metadata)
+
+ def test_deepcopy(self):
+ blend = self.blend
+ blend.metadata = {"test": "value"}
+ blend_copy = blend.copy(deep=True)
+
+ self.assertIsNot(blend_copy, blend)
+ self.assertEqual(len(blend_copy.sources), len(blend.sources))
+ for source_copy, source in zip(blend_copy.sources, blend.sources):
+ self.assertSourceEqual(source_copy, source)
+
+ with self.assertRaises(AssertionError):
+ source_copy.components[0]._spectrum.x += 1
+ self.assertSourceEqual(source_copy, source)
+
+ self.assertObservationEqual(blend_copy.observation, blend.observation)
+ self.assertDictEqual(blend_copy.metadata, blend.metadata)
+ blend_copy.metadata["test"] = "new_value"
+ with self.assertRaises(AssertionError):
+ self.assertDictEqual(blend_copy.metadata, blend.metadata)
+
+ def test_slice(self):
+ blend = self.blend
+ blend.metadata = {"test": "value"}
+ blend_sliced = blend["g":"r"]
+ self.assertEqual(len(blend.sources), len(blend_sliced.sources))
+
+ for source_sliced, source in zip(blend_sliced.sources, blend.sources):
+ self.assertSourceEqual(source_sliced, source["g":"r"])
+
+ self.assertObservationEqual(blend_sliced.observation, blend.observation["g":"r"])
+ self.assertDictEqual(blend_sliced.metadata, blend.metadata)
+
+ def test_reorder(self):
+ blend = self.blend
+ blend.metadata = {"test": "value"}
+ indices = ("i", "g", "r")
+ blend_reordered = blend[indices]
+ self.assertEqual(len(blend.sources), len(blend_reordered.sources))
+
+ for source_reordered, source in zip(blend_reordered.sources, blend.sources):
+ self.assertSourceEqual(source_reordered, source[indices])
+
+ self.assertObservationEqual(blend_reordered.observation, blend.observation[indices])
+ self.assertDictEqual(blend_reordered.metadata, blend.metadata)
+
+ def test_subset(self):
+ blend = self.blend
+ blend.metadata = {"test": "value"}
+ blend_subset = blend[("r",)]
+ self.assertEqual(len(blend.sources), len(blend_subset.sources))
+
+ for source_subset, source in zip(blend_subset.sources, blend.sources):
+ self.assertSourceEqual(source_subset, source["r"])
+
+ self.assertObservationEqual(blend_subset.observation, blend.observation["r"])
+ self.assertDictEqual(blend_subset.metadata, blend.metadata)
+
+ def test_indexing_errors(self):
+ blend = self.blend
+
+ with self.assertRaises(IndexError):
+ blend["x"]
+
+ with self.assertRaises(IndexError):
+ blend[("r", "x")]
+
+ with self.assertRaises(IndexError):
+ blend["r":"x"]
+
+ with self.assertRaises(IndexError):
+ blend["x":"i"]
+
+ with self.assertRaises(IndexError):
+ blend["g", "x", "i"]
+
+ with self.assertRaises(IndexError):
+ blend[Box((0, 0), (10, 10))]
+
+ with self.assertRaises(IndexError):
+ blend[:, 10:20, 10:20]
+
+ with self.assertRaises(IndexError):
+ blend[1:]
+
+ with self.assertRaises(IndexError):
+ blend[1]
+
+ with self.assertRaises(IndexError):
+ blend[0, 1]
diff --git a/tests/test_component.py b/tests/test_component.py
index e054675..85871cc 100644
--- a/tests/test_component.py
+++ b/tests/test_component.py
@@ -21,17 +21,20 @@
from __future__ import annotations
-from typing import Callable
+from abc import ABC
+from typing import Any, Callable
import numpy as np
from lsst.scarlet.lite import Box, Image, Parameter
from lsst.scarlet.lite.component import (
Component,
+ CubeComponent,
FactorizedComponent,
default_adaprox_parameterization,
default_fista_parameterization,
)
from lsst.scarlet.lite.operators import Monotonicity
+from lsst.scarlet.lite.utils import integrated_circular_gaussian
from numpy.testing import assert_almost_equal, assert_array_equal
from utils import ScarletTestCase
@@ -52,8 +55,92 @@ def parameterize(self, parameterization: Callable) -> None:
def to_data(self) -> DummyComponent:
pass
+ def __getitem__(self, indices: Any) -> DummyComponent:
+ pass
+
+ def __copy__(self) -> DummyComponent:
+ pass
+
+ def __deepcopy__(self, memo: dict[int, Any]) -> DummyComponent:
+ pass
+
+
+class _ComponentTestBase(ABC):
+ def test_slice(self):
+ component = self.component
+ component_sliced = component["g":"r"]
+ self.assertTupleEqual(component_sliced.bands, ("g", "r"))
+ np.testing.assert_array_equal(component_sliced.get_model(), component.get_model().data[0:2])
+
+ def test_reorder(self):
+ component = self.component
+ indices = ("i", "g", "r")
+ component_reordered = component["i", "g", "r"]
+ self.assertTupleEqual(component_reordered.bands, indices)
+ np.testing.assert_array_equal(
+ component_reordered.get_model(),
+ component.get_model().data[(2, 0, 1),],
+ )
+
+ component_reordered = component["igr"]
+ self.assertTupleEqual(component_reordered.bands, indices)
+ np.testing.assert_array_equal(
+ component_reordered.get_model(),
+ component.get_model().data[(2, 0, 1),],
+ )
+
+ def test_subset(self):
+ component = self.component
+ indices = ("r",)
+ component_subset = component["r"]
+ self.assertTupleEqual(component_subset.bands, indices)
+ np.testing.assert_array_equal(
+ component_subset.get_model(),
+ component.get_model().data[1:2,],
+ )
+
+ component = self.component.copy(deep=True)
+ component._bands = ("ab", "cd", "ef")
+ indices = "ab"
+ component_reordered = component["ab"]
+ self.assertTupleEqual(component_reordered.bands, (indices,))
+ np.testing.assert_array_equal(
+ component_reordered.get_model(),
+ component.get_model().data[0:1,],
+ )
+
+ def test_indexing_errors(self):
+ component = self.component
+ print("bands", component.bands)
+ with self.assertRaises(IndexError):
+ component["z"]
+
+ with self.assertRaises(IndexError):
+ component["r":"z"]
+
+ with self.assertRaises(IndexError):
+ component["z":"i"]
+
+ with self.assertRaises(IndexError):
+ component["g", "z", "i"]
+
+ with self.assertRaises(IndexError):
+ component[Box((0, 0), (10, 10))]
+
+ with self.assertRaises(IndexError):
+ component[:, 10:20, 10:20]
+
+ with self.assertRaises(IndexError):
+ component[1:]
+
+ with self.assertRaises(IndexError):
+ component[1]
-class TestFactorizedComponent(ScarletTestCase):
+ with self.assertRaises(IndexError):
+ component[0, 1]
+
+
+class TestFactorizedComponent(_ComponentTestBase, ScarletTestCase):
def setUp(self) -> None:
spectrum = np.arange(3).astype(np.float32)
morph = np.arange(20).reshape(4, 5).astype(np.float32)
@@ -246,3 +333,82 @@ def test_parameterization(self):
with self.assertRaises(NotImplementedError):
default_adaprox_parameterization(DummyComponent(*params))
+
+ def test_shallow_copy(self):
+ component = self.component
+ component.monotonicity = Monotonicity((11, 11), fit_radius=0)
+
+ component_copy = component.copy()
+
+ self.assertIsNot(component, component_copy)
+ np.testing.assert_array_equal(component._spectrum.x, component_copy._spectrum.x)
+ np.testing.assert_array_equal(component._morph.x, component_copy._morph.x)
+ self.assertIs(component.bbox, component_copy.bbox)
+ self.assertIs(component.peak, component_copy.peak)
+ self.assertIs(component.bg_thresh, component_copy.bg_thresh)
+ self.assertIs(component.monotonicity, component_copy.monotonicity)
+
+ def test_deep_copy(self):
+ component = self.component
+ component.monotonicity = Monotonicity((11, 11), fit_radius=0)
+ component_deepcopy = component.copy(deep=True)
+
+ self.assertIsNot(component, component_deepcopy)
+
+ np.testing.assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)
+ component_deepcopy._spectrum.x += 1
+ with self.assertRaises(AssertionError):
+ np.testing.assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)
+
+ np.testing.assert_array_equal(component._morph.x, component_deepcopy._morph.x)
+ component_deepcopy._morph.x += 1
+ with self.assertRaises(AssertionError):
+ np.testing.assert_array_equal(component._morph.x, component_deepcopy._morph.x)
+
+ self.assertIsNot(component.bbox, component_deepcopy.bbox)
+ self.assertBoxEqual(component.bbox, component_deepcopy.bbox)
+
+ self.assertTupleEqual(component.peak, component_deepcopy.peak)
+ self.assertEqual(component.bg_thresh, component_deepcopy.bg_thresh)
+ self.assertIsNot(component.monotonicity, component_deepcopy.monotonicity)
+
+
+class TestCubeComponent(_ComponentTestBase, ScarletTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ self.bands = tuple("gri")
+ peak = (27, 32)
+ bbox = Box((15, 15), (20, 25))
+ morph = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
+ spectrum = np.arange(3, dtype=np.float32)
+ model = morph[None, :, :] * spectrum[:, None, None]
+ model_image = Image(model, yx0=bbox.origin, bands=self.bands)
+ self.component = CubeComponent(model=model_image, peak=peak)
+
+ def test_constructor(self):
+ component = self.component
+ self.assertIsInstance(component._model, Image)
+ np.testing.assert_array_equal(component._model.data, self.component._model.data)
+ self.assertTupleEqual(component.bands, self.bands)
+ self.assertBoxEqual(component.bbox, Box((15, 15), (20, 25)))
+ self.assertTupleEqual(component.peak, (27, 32))
+
+ def test_shallow_copy(self):
+ component = self.component
+ component_copy = component.copy()
+
+ self.assertIsNot(component_copy, component)
+ self.assertTupleEqual(component_copy.peak, component.peak)
+ self.assertImageEqual(component_copy._model, component._model)
+
+ def test_deep_copy(self):
+ component = self.component
+ component_copy = component.copy(deep=True)
+
+ self.assertIsNot(component, component_copy)
+
+ self.assertTupleEqual(component_copy.peak, component.peak)
+ self.assertImageEqual(component_copy._model, component._model)
+ with self.assertRaises(AssertionError):
+ component_copy._model._data -= 1
+ self.assertImageEqual(component_copy._model, component._model)
diff --git a/tests/test_io.py b/tests/test_io.py
index 9e8cb4a..9b2b8ed 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -24,8 +24,8 @@
import numpy as np
from lsst.scarlet.lite import Blend, Image, Observation, io
+from lsst.scarlet.lite.component import CubeComponent
from lsst.scarlet.lite.initialization import FactorizedInitialization
-from lsst.scarlet.lite.io import ComponentCube
from lsst.scarlet.lite.operators import Monotonicity
from lsst.scarlet.lite.utils import integrated_circular_gaussian
from numpy.testing import assert_almost_equal
@@ -107,7 +107,7 @@ def test_cube_component(self):
blend.sources[i].metadata = {"id": f"peak-{i}"}
component = blend.sources[-1].components[-1]
# Replace one of the components with a Free-Form component.
- blend.sources[-1].components[-1] = ComponentCube(
+ blend.sources[-1].components[-1] = CubeComponent(
model=component.get_model(),
peak=component.peak,
)
diff --git a/tests/test_measure.py b/tests/test_measure.py
index 1a6eae3..9082cf6 100644
--- a/tests/test_measure.py
+++ b/tests/test_measure.py
@@ -22,8 +22,8 @@
import os
import numpy as np
-from lsst.scarlet.lite import Blend, Image, Observation, Source, io
-from lsst.scarlet.lite.component import default_adaprox_parameterization
+from lsst.scarlet.lite import Blend, Image, Observation, Source
+from lsst.scarlet.lite.component import CubeComponent, default_adaprox_parameterization
from lsst.scarlet.lite.initialization import FactorizedInitialization
from lsst.scarlet.lite.measure import calculate_snr
from lsst.scarlet.lite.operators import Monotonicity
@@ -84,7 +84,7 @@ def test_conserve_flux(self):
blend.sources.append(
Source(
[
- io.ComponentCube(
+ CubeComponent(
model=Image(
np.ones(observation.shape, dtype=observation.dtype),
observation.bands,
diff --git a/tests/test_observation.py b/tests/test_observation.py
index 65dc1d1..95bbe30 100644
--- a/tests/test_observation.py
+++ b/tests/test_observation.py
@@ -285,3 +285,22 @@ def test_slicing(self):
observation.noise_rms[1:4],
)
self.assertBoxEqual(sliced_observation.bbox, new_box)
+
+ def test_shallow_copy(self):
+ observation_copy = self.observation.copy()
+ self.assertObservationEqual(observation_copy, self.observation)
+
+ def test_deep_copy(self):
+ observation_copy = self.observation.copy(deep=True)
+ self.assertObservationEqual(observation_copy, self.observation)
+
+ # Modify the copy and check that the original is unchanged
+ observation_copy.images._data += 1
+ with self.assertRaises(AssertionError):
+ self.assertImageEqual(observation_copy.images, self.observation.images)
+ observation_copy.variance._data += 1
+ with self.assertRaises(AssertionError):
+ self.assertImageEqual(observation_copy.variance, self.observation.variance)
+ observation_copy.weights._data += 1
+ with self.assertRaises(AssertionError):
+ self.assertImageEqual(observation_copy.weights, self.observation.weights)
diff --git a/tests/test_parameters.py b/tests/test_parameters.py
index 043398e..cb88aa9 100644
--- a/tests/test_parameters.py
+++ b/tests/test_parameters.py
@@ -140,3 +140,82 @@ def test_fixed_parameter(self):
param = FixedParameter(x)
param.update(10, np.arange(10) * 2)
assert_array_equal(param.x, x)
+
+ def test_shallow_copy(self):
+ x = np.arange(10, dtype=float)
+
+ # FistaParameter
+ param = FistaParameter(x, 0.1)
+ param_copy = param.copy()
+ self.assertIsInstance(param_copy, FistaParameter)
+
+ assert_array_equal(param.x, param_copy.x)
+ assert_array_equal(param.helpers["z"], param_copy.helpers["z"])
+
+ # AdaproxParameter
+ param = AdaproxParameter(x, 0.1)
+ param_copy = param.copy()
+ self.assertIsInstance(param_copy, AdaproxParameter)
+
+ assert_array_equal(param.x, param_copy.x)
+ assert_array_equal(param.helpers["m"], param_copy.helpers["m"])
+ assert_array_equal(param.helpers["v"], param_copy.helpers["v"])
+ assert_array_equal(param.helpers["vhat"], param_copy.helpers["vhat"])
+
+ # FixedParameter
+ param = FixedParameter(x)
+ param_copy = param.copy()
+ self.assertIsInstance(param_copy, FixedParameter)
+ assert_array_equal(param.x, param_copy.x)
+
+ def test_deep_copy(self):
+ x = np.arange(10, dtype=float)
+
+ # FistaParameter
+ param = FistaParameter(x, 0.1)
+ param_deepcopy = param.copy(deep=True)
+ self.assertIsInstance(param_deepcopy, FistaParameter)
+
+ assert_array_equal(param.x, param_deepcopy.x)
+ param_deepcopy.x += 1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.x, param_deepcopy.x)
+
+ assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])
+ param_deepcopy.helpers["z"] += 1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])
+
+ # AdaproxParameter
+ param = AdaproxParameter(x, 0.1)
+ param_deepcopy = param.copy(deep=True)
+ self.assertIsInstance(param_deepcopy, AdaproxParameter)
+
+ assert_array_equal(param.x, param_deepcopy.x)
+ param_deepcopy.x += 1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.x, param_deepcopy.x)
+
+ assert_array_equal(param.helpers["m"], param_deepcopy.helpers["m"])
+ param_deepcopy.helpers["m"] = -1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.helpers["m"], param_deepcopy.helpers["m"])
+
+ assert_array_equal(param.helpers["v"], param_deepcopy.helpers["v"])
+ param_deepcopy.helpers["v"] = -1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.helpers["v"], param_deepcopy.helpers["v"])
+
+ assert_array_equal(param.helpers["vhat"], param_deepcopy.helpers["vhat"])
+ param_deepcopy.helpers["vhat"] = -1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.helpers["vhat"], param_deepcopy.helpers["vhat"])
+
+ # FixedParameter
+ param = FixedParameter(x)
+ param_deepcopy = param.copy(deep=True)
+ self.assertIsInstance(param_deepcopy, FixedParameter)
+ assert_array_equal(param.x, param_deepcopy.x)
+ param_deepcopy.x += 1
+ with self.assertRaises(AssertionError):
+ assert_array_equal(param.x, param_deepcopy.x)
diff --git a/tests/test_source.py b/tests/test_source.py
index 437c6c4..591a4bc 100644
--- a/tests/test_source.py
+++ b/tests/test_source.py
@@ -27,8 +27,34 @@
class TestSource(ScarletTestCase):
- def test_constructor(self):
- # Test empty source
+ def setUp(self) -> None:
+ super().setUp()
+ self.bands = tuple("grizy")
+ self.center = (27, 32)
+ self.morph1 = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
+ self.spectrum1 = np.arange(5).astype(np.float32)
+ self.component_box1 = Box((15, 15), (20, 25))
+ self.morph2 = integrated_circular_gaussian(sigma=2.1).astype(np.float32)
+ self.spectrum2 = np.arange(5)[::-1].astype(np.float32)
+ self.component_box2 = Box((15, 15), (10, 35))
+ self.components = [
+ FactorizedComponent(
+ self.bands,
+ self.spectrum1,
+ self.morph1,
+ self.component_box1,
+ self.center,
+ ),
+ FactorizedComponent(
+ self.bands,
+ self.spectrum2,
+ self.morph2,
+ self.component_box2,
+ self.center,
+ ),
+ ]
+
+ def test_empty_constructor(self):
source = Source([])
self.assertEqual(source.n_components, 0)
@@ -38,74 +64,41 @@ def test_constructor(self):
self.assertBoxEqual(source.bbox, Box((0, 0)))
self.assertTupleEqual(source.bands, ())
- # Test a source with a single component
- bands = tuple("grizy")
- center = (27, 32)
- morph1 = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
- spectrum1 = np.arange(5).astype(np.float32)
- component_box1 = Box((15, 15), (20, 25))
- components = [
- FactorizedComponent(
- bands,
- spectrum1,
- morph1,
- component_box1,
- center,
- ),
- ]
- source = Source(components)
+ def test_single_component_constructor(self):
+ source = Source(self.components[:1])
self.assertEqual(source.n_components, 1)
- self.assertTupleEqual(source.center, center)
+ self.assertTupleEqual(source.center, self.center)
self.assertTupleEqual(source.source_center, (7, 7))
self.assertFalse(source.is_null)
- self.assertBoxEqual(source.bbox, component_box1)
- self.assertTupleEqual(source.bands, bands)
+ self.assertBoxEqual(source.bbox, self.component_box1)
+ self.assertTupleEqual(source.bands, self.bands)
self.assertImageEqual(
source.get_model(),
Image(
- spectrum1[:, None, None] * morph1[None, :, :],
- yx0=component_box1.origin,
- bands=bands,
+ self.spectrum1[:, None, None] * self.morph1[None, :, :],
+ yx0=self.component_box1.origin,
+ bands=self.bands,
),
)
self.assertIsNone(source.get_model(True))
self.assertEqual(source.get_model().dtype, np.float32)
+ def test_multiple_component_constructor(self):
# Test a source with multiple components
- morph2 = integrated_circular_gaussian(sigma=2.1).astype(np.float32)
- spectrum2 = np.arange(5)[::-1].astype(np.float32)
- component_box2 = Box((15, 15), (10, 35))
-
- components = [
- FactorizedComponent(
- bands,
- spectrum1,
- morph1,
- component_box1,
- center,
- ),
- FactorizedComponent(
- bands,
- spectrum2,
- morph2,
- component_box2,
- center,
- ),
- ]
- source = Source(components)
+ source = Source(self.components)
self.assertEqual(source.n_components, 2)
- self.assertTupleEqual(source.center, center)
+ self.assertTupleEqual(source.center, self.center)
self.assertTupleEqual(source.source_center, (17, 7))
self.assertFalse(source.is_null)
self.assertBoxEqual(source.bbox, Box((25, 25), (10, 25)))
- self.assertTupleEqual(source.bands, bands)
+ self.assertTupleEqual(source.bands, self.bands)
self.assertEqual(str(source), "Source<2>")
self.assertEqual(source.get_model().dtype, np.float32)
model = np.zeros((5, 25, 25), dtype=np.float32)
- model[:, 10:25, :15] = spectrum1[:, None, None] * morph1[None, :, :]
- model[:, :15, 10:25] += spectrum2[:, None, None] * morph2[None, :, :]
- model = Image(model, yx0=(10, 25), bands=tuple("grizy"))
+ model[:, 10:25, :15] = self.spectrum1[:, None, None] * self.morph1[None, :, :]
+ model[:, :15, 10:25] += self.spectrum2[:, None, None] * self.morph2[None, :, :]
+ model = Image(model, yx0=(10, 25), bands=self.bands)
self.assertImageEqual(
source.get_model(),
@@ -116,3 +109,102 @@ def test_constructor(self):
source = Source([])
result = source.get_model()
self.assertEqual(result, 0)
+
+ def test_shallow_copy(self):
+ source = Source(self.components)
+ source_copy = source.copy()
+
+ self.assertIsNot(source, source_copy)
+ self.assertEqual(source.n_components, 2)
+ self.assertEqual(source.n_components, source_copy.n_components)
+ self.assertFactorizedComponentEqual(source.components[0], source_copy.components[0])
+ self.assertFactorizedComponentEqual(source.components[1], source_copy.components[1])
+ self.assertIs(source.flux_weighted_image, source_copy.flux_weighted_image)
+ self.assertIs(source.metadata, source_copy.metadata)
+
+ def test_deepcopy(self):
+ source = Source(self.components)
+ source_deepcopy = source.copy(deep=True)
+
+ self.assertIsNot(source, source_deepcopy)
+ self.assertEqual(source.n_components, source_deepcopy.n_components)
+ for comp, comp_deepcopy in zip(source.components, source_deepcopy.components):
+ self.assertIsNot(comp, comp_deepcopy)
+ self.assertFactorizedComponentEqual(comp, comp_deepcopy)
+ comp_deepcopy._spectrum.x += 1
+ with self.assertRaises(AssertionError):
+ np.testing.assert_array_equal(comp._spectrum.x, comp_deepcopy._spectrum.x)
+ comp_deepcopy._morph.x += 1
+ with self.assertRaises(AssertionError):
+ np.testing.assert_array_equal(comp._morph.x, comp_deepcopy._morph.x)
+
+ def test_slice(self):
+ source = Source(self.components)
+ source_sliced = source["g":"r"]
+ self.assertTupleEqual(source_sliced.bands, ("g", "r"))
+ self.assertEqual(source.n_components, source_sliced.n_components)
+
+ for comp, comp_sliced in zip(source.components, source_sliced.components):
+ self.assertFactorizedComponentEqual(comp["g":"r"], comp_sliced)
+
+ def test_reorder(self):
+ source = Source(self.components)
+ indices = ("i", "g", "r")
+ source_reordered = source[indices]
+ self.assertTupleEqual(source_reordered.bands, indices)
+ self.assertEqual(source.n_components, source_reordered.n_components)
+ for comp, comp_reordered in zip(source.components, source_reordered.components):
+ self.assertFactorizedComponentEqual(comp[indices], comp_reordered)
+
+ source_reordered = source["igr"]
+ self.assertTupleEqual(source_reordered.bands, indices)
+ self.assertEqual(source.n_components, source_reordered.n_components)
+ for comp, comp_reordered in zip(source.components, source_reordered.components):
+ self.assertFactorizedComponentEqual(comp["igr"], comp_reordered)
+
+ def test_subset(self):
+ source = Source(self.components)
+ source_subset = source[("r",)]
+ self.assertTupleEqual(source_subset.bands, ("r",))
+ self.assertEqual(source.n_components, source_subset.n_components)
+ for comp, comp_subset in zip(source.components, source_subset.components):
+ self.assertFactorizedComponentEqual(comp["r"], comp_subset)
+
+ source = source.copy(deep=True)
+ for comp in source.components:
+ comp._bands = ("ab", "cd", "ef")
+ source_subset = source["ab"]
+ self.assertTupleEqual(source_subset.bands, ("ab",))
+ self.assertEqual(source.n_components, source_subset.n_components)
+ for comp, comp_subset in zip(source.components, source_subset.components):
+ self.assertFactorizedComponentEqual(comp["ab"], comp_subset)
+
+ def test_indexing_errors(self):
+ source = Source(self.components)
+
+ with self.assertRaises(IndexError):
+ source["x"]
+
+ with self.assertRaises(IndexError):
+ source["r":"x"]
+
+ with self.assertRaises(IndexError):
+ source["x":"i"]
+
+ with self.assertRaises(IndexError):
+ source["g", "x", "i"]
+
+ with self.assertRaises(IndexError):
+ source[Box((0, 0), (10, 10))]
+
+ with self.assertRaises(IndexError):
+ source[:, 10:20, 10:20]
+
+ with self.assertRaises(IndexError):
+ source[1:]
+
+ with self.assertRaises(IndexError):
+ source[1]
+
+ with self.assertRaises(IndexError):
+ source[0, 1]
diff --git a/tests/utils.py b/tests/utils.py
index 9ef72e3..c44e67b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,13 +21,15 @@
import sys
import traceback
-from typing import Sequence
+from typing import Sequence, cast
from unittest import TestCase
import numpy as np
from lsst.scarlet.lite.bbox import Box
+from lsst.scarlet.lite.component import FactorizedComponent
from lsst.scarlet.lite.fft import match_kernel
from lsst.scarlet.lite.image import Image
+from lsst.scarlet.lite.source import Source
from lsst.scarlet.lite.utils import integrated_circular_gaussian
from numpy.testing import assert_almost_equal, assert_array_equal
from numpy.typing import DTypeLike
@@ -200,3 +202,38 @@ def assertImageAlmostEqual(self, image: Image, truth: Image, decimal: int = 7):
def assertImageEqual(self, image: Image, truth: Image): # noqa: N802
self.assertImageAlmostEqual(image, truth)
assert_array_equal(image.data, truth.data)
+
+ def assertFactorizedComponentEqual( # noqa: N802
+ self,
+ component: FactorizedComponent,
+ truth: FactorizedComponent,
+ ):
+ self.assertTupleEqual(component.bands, truth.bands)
+ self.assertTupleEqual(component.peak, truth.peak)
+ np.testing.assert_array_equal(component._spectrum.x, truth._spectrum.x)
+ np.testing.assert_array_equal(component._morph.x, truth._morph.x)
+ self.assertBoxEqual(component.bbox, truth.bbox)
+ self.assertEqual(component.bg_rms, truth.bg_rms)
+ self.assertEqual(component.bg_thresh, truth.bg_thresh)
+ self.assertEqual(component.floor, truth.floor)
+ self.assertEqual(component.padding, truth.padding)
+ self.assertEqual(component.is_symmetric, truth.is_symmetric)
+
+ def assertSourceEqual(self, source: Source, truth: Source): # noqa: N802
+ self.assertEqual(source.n_components, truth.n_components)
+ self.assertBoxEqual(source.bbox, truth.bbox)
+ self.assertTupleEqual(source.bands, truth.bands)
+ for comp, comp_truth in zip(source.components, truth.components):
+ self.assertFactorizedComponentEqual(
+ cast(FactorizedComponent, comp),
+ cast(FactorizedComponent, comp_truth),
+ )
+
+ def assertObservationEqual(self, obs: ObservationData, truth: ObservationData): # noqa: N802
+ self.assertImageEqual(obs.images, truth.images)
+ self.assertImageEqual(obs.variance, truth.variance)
+ self.assertImageEqual(obs.weights, truth.weights)
+ assert_array_equal(obs.psfs, truth.psfs)
+ assert_array_equal(obs.model_psf, truth.model_psf)
+ assert_array_equal(obs.noise_rms, truth.noise_rms)
+ self.assertBoxEqual(obs.bbox, truth.bbox)