diff --git a/doc/changes/DM-47738.feature.rst b/doc/changes/DM-47738.feature.rst deleted file mode 100644 index d96cacf..0000000 --- a/doc/changes/DM-47738.feature.rst +++ /dev/null @@ -1,3 +0,0 @@ -* Added ``Image.trimmed`` method to remove data below a threshold from an Image. -* Added ``Image.at`` method to extract a single pixel from an Image. -* Added slicing for ``Observation`` to slice along spectral or spatial dimension \ No newline at end of file diff --git a/doc/changes/DM-49537.md b/doc/changes/DM-49537.md deleted file mode 100644 index 33d268a..0000000 --- a/doc/changes/DM-49537.md +++ /dev/null @@ -1 +0,0 @@ -Improved the storage of sources and components using a registry. This nmakes it easier for end users to persist custom component types. \ No newline at end of file diff --git a/doc/lsst.scarlet.lite/changes.rst b/doc/lsst.scarlet.lite/changes.rst new file mode 100644 index 0000000..2c2f5bf --- /dev/null +++ b/doc/lsst.scarlet.lite/changes.rst @@ -0,0 +1,38 @@ +.. _lsst.scarlet.lite-changes: + +================= +v30.0.0 Changes +================= + +Improved Slicing +---------------- + +In v29.0 slicing was only supported for the ``Image`` class. In ``v30.0`` slicing has been extended to ``Blend``, ``Source``, ``Component``, and ``Observation`` classes. This allows uers to use a subset of bands or change the band order of each of these classes. ``Observation`` can be sliced along the spatial dimension as well. + +New Image methods +----------------- +- ``Image.trimmed`` method added to remove data below a threshold from an Image. +- ``Image.at`` method added to extract a single pixel from an Image. + +Serialization Improvments +------------------------- +Serialization received a major update in ``v30.0`` to improve performance and usability. The major upgrade is a ``Migration Registry`` that registers all scarlet serializable classes and allows users to register their own custom classes along with function to migrate between different versions of the class. +This allows scarlet to automatically handle versioning of serialized objects and migrate them to the latest version when deserializing. + +As part of this update the classes used for serialization were expanded to included base classes such as ``BlendBaseData``, ``SourceBaseData``, and ``ComponentBaseData`` to make it easier for users to extend serialization to their own custom classes. + +Copying and Deep Copying +------------------------ +To support the serialization improvements, ``__copy__`` and ``__deepcopy__`` methods were added to nearly all scarlet lite classes. This allows users to create copies of scarlet objects using the standard library ``copy`` and ``deepcopy`` modules. + +Initialization Changes +---------------------- +Initialization has a few changes to make it both more customizable and more standardized +- The ineffective ``FactorizedWaveletInitialization`` was deprecated and the ``FactorizedChirInitialization`` was changed to ``FactorizedInitialization`` as testing has shown that it is a superior algorithm. +- A large number of new parameters were added to ``FactorizedInitialization`` to make it easier for users to initialize sources using more tunable constraints and detection images that may differ from the observation image. + +Other Changes +------------- +- The detection algorithms have been updated and while not ready for production users should see improved performance. +- A more customizable ``conserve_flux`` function was added to ``scarlet`` +- Positions are rounded instead of truncated during initialization to remove a bias from inaccurate sources positions. diff --git a/doc/lsst.scarlet.lite/index.rst b/doc/lsst.scarlet.lite/index.rst index c510f93..201227c 100644 --- a/doc/lsst.scarlet.lite/index.rst +++ b/doc/lsst.scarlet.lite/index.rst @@ -20,6 +20,7 @@ toctree linking to topics related to using the module's APIs. getting_started detection + changes .. _lsst.scarlet.lite-contributing: diff --git a/pyproject.toml b/pyproject.toml index a3041be..8d181df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ test = [ "pytest >= 3.2", ] yaml = ["pyyaml >= 5.1"] -plotting = ["matplotlib", "astropy < 7"] +plotting = ["matplotlib", "astropy >= 6.1"] [tool.setuptools.packages.find] where = ["python"] diff --git a/python/lsst/scarlet/lite/bbox.py b/python/lsst/scarlet/lite/bbox.py index f861673..73cb04c 100644 --- a/python/lsst/scarlet/lite/bbox.py +++ b/python/lsst/scarlet/lite/bbox.py @@ -23,7 +23,8 @@ __all__ = ["Box", "overlapped_slices"] -from typing import Sequence, cast +from copy import deepcopy +from typing import Any, Sequence, cast import numpy as np @@ -462,6 +463,15 @@ def __matmul__(self, bbox: Box) -> Box: result = Box.from_bounds(*bounds) return result + def __deepcopy__(self, memo: dict[int, Any]) -> Box: + """Deep copy of the box""" + my_id = id(self) + if my_id in memo: + return memo[my_id] + result = Box(deepcopy(self.shape), origin=deepcopy(self.origin)) + memo[my_id] = result + return result + def __copy__(self) -> Box: """Copy of the box""" return Box(self.shape, origin=self.origin) diff --git a/python/lsst/scarlet/lite/blend.py b/python/lsst/scarlet/lite/blend.py index e8610ab..4919585 100644 --- a/python/lsst/scarlet/lite/blend.py +++ b/python/lsst/scarlet/lite/blend.py @@ -24,7 +24,8 @@ __all__ = ["Blend"] from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Sequence, cast +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, Self, Sequence, cast import numpy as np @@ -32,7 +33,7 @@ from .component import Component, FactorizedComponent from .image import Image from .observation import Observation -from .source import Source +from .source import Source, SourceBase if TYPE_CHECKING: from .io import ScarletBlendData, ScarletSourceBaseData @@ -57,7 +58,7 @@ class BlendBase(ABC): Additional metadata to store with the blend. """ - sources: list[Source] + sources: Sequence[SourceBase] observation: Observation metadata: dict | None @@ -80,6 +81,72 @@ def components(self) -> list[Component]: """ return [c for src in self.sources for c in src.components] + @abstractmethod + def __getitem__(self, indices: Any) -> Self: + """Get a sub-blend corresponding to the given indices. + + Parameters + ---------- + indices : + The indices to use to slice the blend. + + Returns + ------- + sub_blend : + A new `BlendBase` instance containing only data from the + specified bands in the specified order. + + Raises + ------ + IndexError : + If the indices contain bands not included in the original + blend or any spatial indices are given. + """ + + @abstractmethod + def __copy__(self) -> Self: + """Create a copy of this blend. + + Returns + ------- + blend : BlendBase + A new blend that is a copy of this one. + """ + + @abstractmethod + def __deepcopy__(self, memo: dict[int, Any]) -> Self: + """Create a deep copy of this blend. + + Parameters + ---------- + memo : dict[int, Any] + A memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + blend : BlendBase + A new blend that is a deep copy of this one. + """ + + def copy(self, deep: bool = False) -> Self: + """Create a copy of this blend. + + Parameters + ---------- + deep : + If `True`, a deep copy is made. If `False`, a shallow copy is made. + Default is `False`. + + Returns + ------- + blend : Self + A new blend that is a copy of this one. + """ + if deep: + return self.__deepcopy__({}) + else: + return self.__copy__() + @abstractmethod def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image: """Generate a model of the entire blend. @@ -128,6 +195,8 @@ class Blend(BlendBase): Additional metadata to store with the blend. """ + sources: list[Source] + def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None): self.sources = list(sources) self.observation = observation @@ -433,3 +502,69 @@ def to_data(self) -> ScarletBlendData: ) return blend_data + + def __getitem__(self, indices: Any) -> Blend: + """Get a sub-blend corresponding to the given indices. + + Parameters + ---------- + indices : + The indices to use to slice the blend. + + Returns + ------- + blend : + A new `Blend` instance containing only data from the + specified bands in the specified order. + + Raises + ------ + IndexError : + If the indices contain bands not included in the original + blend or a bounding box is given. + """ + return Blend( + sources=[src[indices] for src in self.sources], + observation=self.observation[indices], + metadata=self.metadata, + ) + + def __copy__(self) -> Blend: + """Create a copy of this blend. + + Returns + ------- + blend : Blend + A new blend that is a copy of this one. + """ + return Blend(sources=self.sources, observation=self.observation, metadata=self.metadata) + + def __deepcopy__(self, memo: dict[int, Any]) -> Blend: + """Create a deep copy of this blend. + + Parameters + ---------- + memo : dict[int, Any] + A memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + blend : Blend + A new blend that is a deep copy of this one. + """ + # Check if already copied + if id(self) in memo: + return memo[id(self)] + + # Create placeholder and add to memo FIRST + blend = Blend.__new__(Blend) + memo[id(self)] = blend + + # Now safely initialize the placeholder with deepcopied arguments + blend.__init__( # type: ignore[misc] + sources=[deepcopy(src, memo) for src in self.sources], + observation=deepcopy(self.observation, memo), + metadata=deepcopy(self.metadata, memo), + ) + + return blend diff --git a/python/lsst/scarlet/lite/component.py b/python/lsst/scarlet/lite/component.py index 81f109d..0e02d83 100644 --- a/python/lsst/scarlet/lite/component.py +++ b/python/lsst/scarlet/lite/component.py @@ -18,11 +18,13 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . - from __future__ import annotations +from copy import deepcopy + __all__ = [ "Component", + "CubeComponent", "FactorizedComponent", "default_fista_parameterization", "default_adaprox_parameterization", @@ -30,7 +32,7 @@ from abc import ABC, abstractmethod from functools import partial -from typing import TYPE_CHECKING, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, cast import numpy as np @@ -38,9 +40,14 @@ from .image import Image from .operators import Monotonicity, prox_uncentered_symmetry from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step +from .utils import convert_indices if TYPE_CHECKING: - from .io import ScarletComponentBaseData + from .io import ScarletComponentBaseData, ScarletCubeComponentData + +import logging + +Logger = logging.getLogger(__name__) class Component(ABC): @@ -126,6 +133,64 @@ def to_data(self) -> ScarletComponentBaseData: The data object containing the component information """ + @abstractmethod + def __getitem__(self, indices: Any) -> Component: + """Get a sub-component corresponding to the given indices. + + Parameters + ---------- + indices: Any + The indices to use to slice the component model. + + Returns + ------- + sub_component: Component + A new component that is a sub-component of this one. + + Raises + ------ + IndexError : + If the index includes a ``Box`` or spatial indices. + """ + + @abstractmethod + def __copy__(self) -> Component: + """Create a copy of this component. + + Returns + ------- + component : Component + A new component that is a copy of this one. + """ + + @abstractmethod + def __deepcopy__(self, memo: dict[int, Any]) -> Component: + """Create a deep copy of this component. + + Returns + ------- + component : Component + A new component that is a deep copy of this one. + """ + + def copy(self, deep: bool = False) -> Component: + """Create a copy of this component. + + Parameters + ---------- + deep : bool, optional + If `True`, a deep copy is made. If `False`, a shallow copy is made. + Default is `False`. + + Returns + ------- + component : Component + A new component that is a copy of this one. + """ + if deep: + return self.__deepcopy__({}) + return self.__copy__() + class FactorizedComponent(Component): """A component that can be factorized into spectrum and morphology @@ -395,6 +460,234 @@ def __str__(self): def __repr__(self): return self.__str__() + def __getitem__(self, indices: Any) -> FactorizedComponent: + """Get a sub-component corresponding to the given indices. + + Parameters + ---------- + indices: Any + The indices to use to slice the component model. + + Returns + ------- + component: FactorizedComponent + A new component that is a sub-component of this one. + + Raises + ------ + IndexError : + If the index includes a ``Box`` or spatial indices. + """ + # Convert the band indices into numerical indices + band_indices = convert_indices(self.bands, indices) + if isinstance(band_indices, slice): + bands = self.bands[band_indices] + else: + bands = tuple(self.bands[i] for i in band_indices) + + # Slice the spectrum + spectrum = self._spectrum.x[band_indices,] + + return FactorizedComponent( + bands=bands, + spectrum=spectrum, + morph=self.morph, + bbox=self.bbox, + peak=self.peak, + bg_rms=self.bg_rms, + bg_thresh=self.bg_thresh, + floor=self.floor, + monotonicity=self.monotonicity, + padding=self.padding, + is_symmetric=self.is_symmetric, + ) + + def __deepcopy__(self, memo: dict[int, Any]) -> FactorizedComponent: + """Create a deep copy of this component. + + Parameters + ---------- + memo: dict[int, Any] + The memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + component : FactorizedComponent + A new component that is a deep copy of this one. + """ + # Check if already copied + if id(self) in memo: + return memo[id(self)] + + # Create placeholder and add to memo FIRST + component = FactorizedComponent.__new__(FactorizedComponent) + memo[id(self)] = component + + # Now safely initialize the placeholder with deepcopied arguments + component.__init__( # type: ignore[misc] + bands=deepcopy(self.bands, memo), + spectrum=deepcopy(self.spectrum, memo), + morph=deepcopy(self.morph, memo), + bbox=deepcopy(self.bbox, memo), + peak=deepcopy(self.peak, memo), + bg_rms=deepcopy(self.bg_rms, memo), + bg_thresh=self.bg_thresh, + floor=self.floor, + monotonicity=deepcopy(self.monotonicity, memo), + padding=self.padding, + is_symmetric=self.is_symmetric, + ) + return component + + def __copy__(self) -> FactorizedComponent: + """Create a copy of this component. + + Returns + ------- + component : FactorizedComponent + A new component that is a shallow copy of this one. + """ + return FactorizedComponent( + bands=self.bands, + spectrum=self.spectrum, + morph=self.morph, + bbox=self.bbox, + peak=self.peak, + bg_rms=self.bg_rms, + bg_thresh=self.bg_thresh, + floor=self.floor, + monotonicity=self.monotonicity, + padding=self.padding, + is_symmetric=self.is_symmetric, + ) + + +class CubeComponent(Component): + """Dummy component for a component cube. + + This is duck-typed to a `lsst.scarlet.lite.Component` in order to + generate a model from the component but it is currently not functional + in that it cannot be optimized, only persisted and loaded. + + If scarlet lite ever implements a component as a data cube, + this class can be removed. + """ + + def __init__(self, model: Image, peak: tuple[int, int]): + """Initialization + + Parameters + ---------- + bands : + model : + The 3D (bands, y, x) model of the component. + peak : + The `(y, x)` peak of the component. + bbox : + The bounding box of the component. + """ + super().__init__(model.bands, model.bbox) + self._model = model + self.peak = peak + + def get_model(self) -> Image: + """Generate the model for the source + + Returns + ------- + model : + The model as a 3D `(band, y, x)` array. + """ + return self._model + + def resize(self, model_box: Box) -> bool: + """Resize the component if needed and return whether it was resized""" + Logger.warning("CubeComponent does not support resizing") + return False + + def update(self, it: int, input_grad: np.ndarray) -> None: + """Implementation of unused abstract method""" + Logger.warning("CubeComponent does not support updates") + + def parameterize(self, parameterization: Callable) -> None: + """Implementation of unused abstract method""" + Logger.warning("CubeComponent does not support parameterization") + + def to_data(self) -> ScarletCubeComponentData: + """Convert the component to persistable ScarletComponentData + + Returns + ------- + component_data: ScarletComponentData + The data object containing the component information + """ + from .io import ScarletCubeComponentData + + return ScarletCubeComponentData( + origin=self.bbox.origin, # type: ignore + peak=self.peak, # type: ignore + model=self.get_model().data, + ) + + def __getitem__(self, indices: Any) -> CubeComponent: + """Get a sub-component corresponding to the given indices. + + Parameters + ---------- + indices : + The indices to select. + Returns + ------- + sub_component : + A new component that is a sub-component of this one. + """ + band_indices = convert_indices(self.bands, indices) + if isinstance(band_indices, slice): + bands = self.bands[band_indices] + else: + bands = tuple(self.bands[i] for i in band_indices) + + data = self.get_model()._data[band_indices,] + model = Image(data=data, bands=bands, yx0=cast(tuple[int, int], self.bbox.origin)) + return CubeComponent(model=model, peak=self.peak) + + def __copy__(self) -> CubeComponent: + """Create a copy of this component. + + Returns + ------- + component : ComponentCube + A new component that is a shallow copy of this one. + """ + return CubeComponent(model=self._model, peak=self.peak) + + def __deepcopy__(self, memo: dict[int, Any]) -> CubeComponent: + """Create a deep copy of this component. + + Parameters + ---------- + memo: dict[int, Any] + The memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + component : ComponentCube + A new component that is a deep copy of this one. + """ + if id(self) in memo: + return memo[id(self)] + + # Create placeholder and add to memo FIRST + component = CubeComponent.__new__(CubeComponent) + memo[id(self)] = component + + # Now safely initialize the placeholder with deepcopied arguments + component.__init__( # type: ignore[misc] + model=self._model.copy(), + peak=self.peak, + ) + return component + def default_fista_parameterization(component: Component): """Initialize a factorized component to use FISTA PGM for optimization""" diff --git a/python/lsst/scarlet/lite/image.py b/python/lsst/scarlet/lite/image.py index c692dc0..461079b 100644 --- a/python/lsst/scarlet/lite/image.py +++ b/python/lsst/scarlet/lite/image.py @@ -22,13 +22,14 @@ from __future__ import annotations import operator +from copy import deepcopy from typing import Any, Callable, Sequence, cast import numpy as np from numpy.typing import DTypeLike from .bbox import Box -from .utils import ScalarLike, ScalarTypes +from .utils import ScalarLike, ScalarTypes, convert_indices __all__ = ["Image", "MismatchedBoxError", "MismatchedBandsError"] @@ -54,7 +55,7 @@ def get_dtypes(*data: np.ndarray | Image | ScalarLike) -> list[DTypeLike]: result: A list of datatypes. """ - dtypes: list[DTypeLike] = [None] * len(data) + dtypes: list[DTypeLike] = [float] * len(data) for d, element in enumerate(data): if hasattr(element, "dtype"): dtypes[d] = cast(np.ndarray, element).dtype @@ -496,24 +497,7 @@ def spectral_indices(self, bands: Sequence | slice) -> tuple[int, ...] | slice: band_indices: Tuple of indices for each band in this image. """ - if isinstance(bands, slice): - # Convert a slice of band names into a slice of array indices - # to select the appropriate slice. - if bands.start is None: - start = None - else: - start = self.bands.index(bands.start) - if bands.stop is None: - stop = None - else: - stop = self.bands.index(bands.stop) + 1 - return slice(start, stop, bands.step) - - if isinstance(bands, str): - return (self.bands.index(bands),) - - band_indices = tuple(self.bands.index(band) for band in bands if band in self.bands) - return band_indices + return convert_indices(self.bands, bands) def matched_spectral_indices( self, @@ -545,7 +529,8 @@ def matched_spectral_indices( err = "Attempted to insert a multi-band image into a monochromatic image" raise ValueError(err) - self_indices = cast(tuple[int, ...], self.spectral_indices(other.bands)) + common_bands = tuple(set(self.bands).intersection(set(other.bands))) + self_indices = cast(tuple[int, ...], self.spectral_indices(common_bands)) matched_bands = tuple(self.bands[bidx] for bidx in self_indices) other_indices = cast(tuple[int, ...], other.spectral_indices(matched_bands)) return other_indices, self_indices @@ -682,6 +667,45 @@ def repeat(self, bands: tuple) -> Image: yx0=self.yx0, ) + def __copy__(self) -> Image: + """Make a copy of this image. + + Returns + ------- + image: Image + The copy of this image. + """ + return self.copy_with() + + def __deepcopy__(self, memo: dict[int, Any]) -> Image: + """Make a deep copy of this image. + + Parameters + ---------- + memo: + A dictionary of already copied objects to avoid infinite recursion. + Returns + ------- + image: Image + The deep copy of this image. + """ + # Check if already copied + if id(self) in memo: + return memo[id(self)] + + # Create placeholder and add to memo FIRST + result = Image.__new__(Image) + memo[id(self)] = result + + # Now safely initialize the placeholder with deepcopied arguments + result.__init__( # type: ignore[misc] + data=deepcopy(self.data, memo), + bands=deepcopy(self.bands, memo), + yx0=deepcopy(self.yx0, memo), + ) + + return result + def copy(self, order=None) -> Image: """Make a copy of this image. diff --git a/python/lsst/scarlet/lite/io/blend.py b/python/lsst/scarlet/lite/io/blend.py index 9a2db02..2f360e1 100644 --- a/python/lsst/scarlet/lite/io/blend.py +++ b/python/lsst/scarlet/lite/io/blend.py @@ -6,6 +6,7 @@ from typing import Any import numpy as np +from deprecated.sphinx import deprecated # type: ignore from numpy.typing import DTypeLike from ..bbox import Box @@ -170,6 +171,11 @@ def to_blend(self, observation: Observation) -> Blend: return Blend(sources=sources, observation=observation, metadata=self.metadata) @staticmethod + @deprecated( + reason="ScarletBlendData.from_blend is deprecated. Use blend.to_data() instead.", + version="v30.0", + category=FutureWarning, + ) def from_blend(blend: Blend) -> ScarletBlendData: """Deprecated: Convert a scarlet lite blend into a storage data model. @@ -182,7 +188,6 @@ def from_blend(blend: Blend) -> ScarletBlendData: result : The storage data model representing the blend. """ - logger.warning("ScarletBlendData.from_blend is deprecated. Use blend.to_data() instead.") return blend.to_data() diff --git a/python/lsst/scarlet/lite/io/blend_base.py b/python/lsst/scarlet/lite/io/blend_base.py index 82b4dc5..7d71ce6 100644 --- a/python/lsst/scarlet/lite/io/blend_base.py +++ b/python/lsst/scarlet/lite/io/blend_base.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, ClassVar +import numpy as np from numpy.typing import DTypeLike from ..bbox import Box @@ -54,7 +55,7 @@ def as_dict(self) -> dict: """ @staticmethod - def from_dict(data: dict, dtype: DTypeLike | None = None) -> ScarletBlendBaseData: + def from_dict(data: dict, dtype: DTypeLike = np.float32) -> ScarletBlendBaseData: """Reconstruct `ScarletBlendBaseData` from JSON compatible dict. Parameters diff --git a/python/lsst/scarlet/lite/io/component.py b/python/lsst/scarlet/lite/io/component.py index f67541e..223e64d 100644 --- a/python/lsst/scarlet/lite/io/component.py +++ b/python/lsst/scarlet/lite/io/component.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import ClassVar +import numpy as np from numpy.typing import DTypeLike from ..component import Component @@ -62,7 +63,7 @@ def as_dict(self) -> dict: """ @staticmethod - def from_dict(data: dict, dtype: DTypeLike | None = None) -> ScarletComponentBaseData: + def from_dict(data: dict, dtype: DTypeLike = np.float32) -> ScarletComponentBaseData: """Reconstruct `ScarletComponentBaseData` from JSON compatible dict. diff --git a/python/lsst/scarlet/lite/io/cube_component.py b/python/lsst/scarlet/lite/io/cube_component.py index 9d2acc9..3cd0f47 100644 --- a/python/lsst/scarlet/lite/io/cube_component.py +++ b/python/lsst/scarlet/lite/io/cube_component.py @@ -1,13 +1,14 @@ from __future__ import annotations +import logging from dataclasses import dataclass -from typing import Callable import numpy as np +from deprecated.sphinx import deprecated # type: ignore from numpy.typing import DTypeLike from ..bbox import Box -from ..component import Component +from ..component import CubeComponent from ..image import Image from ..observation import Observation from .component import ScarletComponentBaseData @@ -19,68 +20,20 @@ COMPONENT_TYPE = "cube" MigrationRegistry.set_current(COMPONENT_TYPE, CURRENT_SCHEMA) +logger = logging.getLogger(__name__) -class ComponentCube(Component): - """Dummy component for a component cube. - This is duck-typed to a `lsst.scarlet.lite.Component` in order to - generate a model from the component but it is currently not functional - in that it cannot be optimized, only persisted and loaded. - - If scarlet lite ever implements a component as a data cube, - this class can be removed. - """ +@deprecated( + reason="ComponentCube is deprecated and will be removed after scarlet_lite v30.0. " + "Please use CubeComponent instead.", + version="scarlet_lite v30.0", + category=FutureWarning, +) +class ComponentCube(CubeComponent): + """Deprecated, use CubeComponent instead.""" def __init__(self, model: Image, peak: tuple[int, int]): - """Initialization - - Parameters - ---------- - bands : - model : - The 3D (bands, y, x) model of the component. - peak : - The `(y, x)` peak of the component. - bbox : - The bounding box of the component. - """ - super().__init__(model.bands, model.bbox) - self._model = model - self.peak = peak - - def get_model(self) -> Image: - """Generate the model for the source - - Returns - ------- - model : - The model as a 3D `(band, y, x)` array. - """ - return self._model - - def resize(self, model_box: Box) -> bool: - """Test whether or not the component needs to be resized""" - return False - - def update(self, it: int, input_grad: np.ndarray) -> None: - """Implementation of unused abstract method""" - - def parameterize(self, parameterization: Callable) -> None: - """Implementation of unused abstract method""" - - def to_data(self) -> ScarletCubeComponentData: - """Convert the component to persistable ScarletComponentData - - Returns - ------- - component_data: ScarletComponentData - The data object containing the component information - """ - return ScarletCubeComponentData( - origin=self.bbox.origin, # type: ignore - peak=self.peak, # type: ignore - model=self.get_model().data, - ) + super().__init__(model=model, peak=peak) @dataclass(kw_only=True) @@ -110,7 +63,7 @@ class ScarletCubeComponentData(ScarletComponentBaseData): def shape(self): return self.model.shape[-2:] - def to_component(self, observation: Observation) -> ComponentCube: + def to_component(self, observation: Observation) -> CubeComponent: """Convert the storage data model into a scarlet Component Parameters @@ -130,7 +83,7 @@ def to_component(self, observation: Observation) -> ComponentCube: else: peak = (int(np.round(self.peak[0])), int(np.round(self.peak[0]))) assert peak is not None - component = ComponentCube( + component = CubeComponent( model=Image(model, yx0=bbox.origin, bands=observation.bands), # type: ignore peak=peak, ) diff --git a/python/lsst/scarlet/lite/io/source.py b/python/lsst/scarlet/lite/io/source.py index 5078c86..596b9a9 100644 --- a/python/lsst/scarlet/lite/io/source.py +++ b/python/lsst/scarlet/lite/io/source.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np +from deprecated.sphinx import deprecated # type: ignore from numpy.typing import DTypeLike from ..component import Component @@ -104,6 +105,11 @@ def to_source(self, observation: Observation) -> Source: return Source(components=components, metadata=self.metadata) @staticmethod + @deprecated( + reason="from_source is deprecated and will be removed in a future release.", + version="v30.0", + category=FutureWarning, + ) def from_source(source: Source) -> ScarletSourceData: """Deprecated: Create a `ScarletSourceData` from a scarlet `Source` @@ -117,7 +123,6 @@ def from_source(source: Source) -> ScarletSourceData: result: The `ScarletSourceData` representation of the source. """ - logger.warning("from_source is deprecated and will be removed in a future release.") return source.to_data() diff --git a/python/lsst/scarlet/lite/models/free_form.py b/python/lsst/scarlet/lite/models/free_form.py index 278e105..a1b1202 100644 --- a/python/lsst/scarlet/lite/models/free_form.py +++ b/python/lsst/scarlet/lite/models/free_form.py @@ -18,20 +18,25 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations __all__ = ["FactorizedFreeFormComponent"] -from typing import Callable, cast +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, cast import numpy as np -from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore from ..bbox import Box from ..component import Component, FactorizedComponent from ..detect import footprints_to_image +from ..detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore from ..image import Image from ..parameters import Parameter, parameter +if TYPE_CHECKING: + from ..io.component import ScarletComponentBaseData + class FactorizedFreeFormComponent(FactorizedComponent): """Implements a free-form component @@ -238,3 +243,90 @@ def __str__(self): def __repr__(self): return self.__str__() + + def to_data(self) -> ScarletComponentBaseData: + raise NotImplementedError("Serialization not implemented for FreeFormComponent") + + def __getitem__(self, indices: Any) -> FreeFormComponent: + """Get a sub-component corresponding to the given indices. + + Parameters + ---------- + indices: Any + The indices to use to slice the component model. + + Returns + ------- + component: FreeFormComponent + A new component that is a sub-component of this one. + + Raises + ------ + IndexError : + If the index includes a ``Box`` or spatial indices. + """ + if indices in self.bands: + bands = (indices,) + else: + bands = tuple(indices) + + return FreeFormComponent( + bands=bands, + model=self.model[indices], + model_bbox=self.bbox, + bg_thresh=self.bg_thresh, + bg_rms=self.bg_rms, + floor=self.floor, + peaks=self.peaks, + min_area=self.min_area, + ) + + def __deepcopy__(self, memo: dict[int, Any]) -> FreeFormComponent: + """Create a deep copy of this component. + + Parameters + ---------- + memo: dict[int, Any] + A dictionary to keep track of already copied objects. + + Returns + ------- + component : FreeFormComponent + A new component that is a deep copy of this one. + """ + if id(self) in memo: + return memo[id(self)] + + component = FreeFormComponent.__new__(FreeFormComponent) + memo[id(self)] = component + + component.__init__( # type: ignore[misc] + bands=deepcopy(self.bands), + model=deepcopy(self.model), + model_bbox=deepcopy(self.bbox), + bg_thresh=self.bg_thresh, + bg_rms=deepcopy(self.bg_rms), + floor=self.floor, + peaks=deepcopy(self.peaks), + min_area=self.min_area, + ) + return component + + def __copy__(self) -> FreeFormComponent: + """Create a copy of this component. + + Returns + ------- + component : FreeFormComponent + A new component that is a copy of this one. + """ + return FreeFormComponent( + bands=self.bands, + model=self.model, + model_bbox=self.bbox, + bg_thresh=self.bg_thresh, + bg_rms=self.bg_rms, + floor=self.floor, + peaks=self.peaks, + min_area=self.min_area, + ) diff --git a/python/lsst/scarlet/lite/models/parametric.py b/python/lsst/scarlet/lite/models/parametric.py index ed05f1e..246bacc 100644 --- a/python/lsst/scarlet/lite/models/parametric.py +++ b/python/lsst/scarlet/lite/models/parametric.py @@ -37,7 +37,8 @@ "EllipticalParametricComponent", ] -from typing import TYPE_CHECKING, Callable, Sequence, cast +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, Sequence, cast import numpy as np from scipy.special import erf @@ -47,6 +48,7 @@ from ..component import Component from ..image import Image from ..parameters import Parameter, parameter +from ..utils import convert_indices if TYPE_CHECKING: from ..io import ScarletComponentBaseData @@ -842,6 +844,105 @@ def parameterize(self, parameterization: Callable) -> None: def to_data(self) -> ScarletComponentBaseData: raise NotImplementedError("Saving elliptical parametric components is not yet implemented") + def __getitem__(self, indices: Any) -> ParametricComponent: + """Get a sub-component corresponding to the given indices. + + Parameters + ---------- + indices: Any + The indices to use to slice the component model. + + Returns + ------- + component: ParametricComponent + A new component that is a sub-component of this one. + + Raises + ------ + IndexError : + If the index includes a ``Box`` or spatial indices. + """ + # Update the bands + if indices in self._bands: + # Single band case + bands = (indices,) + else: + # Multiple bands case + bands = tuple(indices) + + # Convert the band indices into numerical indices + band_indices = convert_indices(self.bands, indices) + + # Slice the spectrum + spectrum = self._spectrum.x[band_indices] + + return ParametricComponent( + bands=bands, + bbox=self.bbox, + spectrum=spectrum, + morph_params=self.radial_params, + morph_func=self._func, + morph_grad=self._morph_grad, + morph_prox=self._morph_prox, + morph_step=self._morph_step, + prox_spectrum=self._prox_spectrum, + floor=self.floor, + ) + + def __deepcopy__(self, memo: dict[int, Any]) -> ParametricComponent: + """Create a deep copy of this component + + Parameters + ---------- + memo: + The memoization dictionary used by `copy.deepcopy`. + Returns + ------- + component : ParametricComponent + A new component that is a deep copy of this one. + """ + if id(self) in memo: + return memo[id(self)] + + component = ParametricComponent.__new__(ParametricComponent) + memo[id(self)] = component + + component.__init__( # type: ignore[misc] + bands=deepcopy(self.bands), + bbox=deepcopy(self.bbox), + spectrum=deepcopy(self.spectrum), + morph_params=deepcopy(self.radial_params), + morph_func=self._func, + morph_grad=self._morph_grad, + morph_prox=self._morph_prox, + morph_step=self._morph_step, + prox_spectrum=self._prox_spectrum, + floor=self.floor, + ) + + return component + + def __copy__(self) -> ParametricComponent: + """Create a copy of this component + + Returns + ------- + component : ParametricComponent + A new component that is a shallow copy of this one. + """ + return ParametricComponent( + bands=self.bands, + bbox=self.bbox, + spectrum=self.spectrum, + morph_params=self.radial_params, + morph_func=self._func, + morph_grad=self._morph_grad, + morph_prox=self._morph_prox, + morph_step=self._morph_step, + prox_spectrum=self._prox_spectrum, + floor=self.floor, + ) + class EllipticalParametricComponent(ParametricComponent): """A radial density/surface brightness profile with elliptical symmetry diff --git a/python/lsst/scarlet/lite/observation.py b/python/lsst/scarlet/lite/observation.py index 44fe404..56d13b9 100644 --- a/python/lsst/scarlet/lite/observation.py +++ b/python/lsst/scarlet/lite/observation.py @@ -23,6 +23,7 @@ __all__ = ["Observation", "convolve"] +from copy import deepcopy from typing import Any, cast import numpy as np @@ -63,9 +64,9 @@ def get_filter_coords(filter_values: np.ndarray, center: tuple[int, int] | None calculate `coords` on your own.""" raise ValueError(msg) center = tuple([filter_values.shape[0] // 2, filter_values.shape[1] // 2]) # type: ignore - x = np.arange(filter_values.shape[1]) - y = np.arange(filter_values.shape[0]) - x, y = np.meshgrid(x, y) + _x = np.arange(filter_values.shape[1]) + _y = np.arange(filter_values.shape[0]) + x, y = np.meshgrid(_x, _y) x -= center[1] y -= center[0] coords = np.dstack([y, x]) @@ -366,6 +367,29 @@ def __getitem__(self, indices: Any) -> Observation: new_variance = self.variance[indices] new_weights = self.weights[indices] + # If the indices is a single band, make sure to keep the band axis + if new_image.ndim == 2: + if indices in self.bands: + new_bands = (indices,) + else: + # The indices contain spatial and band indices + new_bands = (indices[0],) + new_image = Image( + new_image.data[None, :, :], + yx0=new_image.yx0, + bands=new_bands, + ) + new_variance = Image( + new_variance.data[None, :, :], + yx0=new_variance.yx0, + bands=new_bands, + ) + new_weights = Image( + new_weights.data[None, :, :], + yx0=new_weights.yx0, + bands=new_bands, + ) + # Extract the appropriate bands from the PSF bands = self.images.bands new_bands = new_image.bands @@ -385,56 +409,67 @@ def __getitem__(self, indices: Any) -> Observation: model_psf=self.model_psf, noise_rms=noise_rms, bbox=new_image.bbox, - bands=self.bands, + bands=new_bands, padding=self.padding, convolution_mode=self.mode, ) - def __copy__(self, deep: bool = False) -> Observation: + def __copy__(self) -> Observation: """Create a copy of the observation - Parameters - ---------- - deep: - Whether to perform a deep copy or not. - Returns ------- result: The copy of the observation. """ - if deep: - if self.model_psf is None: - model_psf = None - else: - model_psf = self.model_psf.copy() + return Observation( + images=self.images, + variance=self.variance, + weights=self.weights, + psfs=self.psfs, + model_psf=self.model_psf, + noise_rms=self.noise_rms, + bands=self.bands, + padding=self.padding, + convolution_mode=self.mode, + ) - if self.noise_rms is None: - noise_rms = None - else: - noise_rms = self.noise_rms.copy() + def __deepcopy__(self, memo: dict[int, Any]) -> Observation: + """Create a deep copy of the observation - if self.bands is None: - bands = None - else: - bands = tuple([b for b in self.bands]) - else: - model_psf = self.model_psf - noise_rms = self.noise_rms - bands = self.bands + Parameters + ---------- + memo: dict[int, Any] + The memoization dictionary used by `copy.deepcopy`. - return Observation( - images=self.images.copy(), - variance=self.variance.copy(), - weights=self.weights.copy(), - psfs=self.psfs.copy(), - model_psf=model_psf, - noise_rms=noise_rms, - bands=bands, - padding=self.padding, + Returns + ------- + result: + The deep copy of the observation. + """ + # Check if already copied + if id(self) in memo: + return memo[id(self)] + + # Create placeholder and add to memo FIRST + result = Observation.__new__(Observation) + memo[id(self)] = result + + # Now safely initialize the placeholder with deepcopied arguments + result.__init__( # type: ignore[misc] + images=deepcopy(self.images, memo), + variance=deepcopy(self.variance, memo), + weights=deepcopy(self.weights, memo), + psfs=deepcopy(self.psfs, memo), + model_psf=deepcopy(self.model_psf, memo), + noise_rms=deepcopy(self.noise_rms, memo), + bands=deepcopy(self.bands, memo), + padding=deepcopy(self.padding, memo), convolution_mode=self.mode, ) + return result + def copy(self, deep: bool = False) -> Observation: """Create a copy of the observation @@ -448,7 +483,9 @@ def copy(self, deep: bool = False) -> Observation: result: The copy of the observation. """ - return self.__copy__(deep) + if deep: + return self.__deepcopy__({}) + return self.__copy__() @property def shape(self) -> tuple[int, int, int]: diff --git a/python/lsst/scarlet/lite/operators.py b/python/lsst/scarlet/lite/operators.py index 1fc9f74..f1d3c6b 100644 --- a/python/lsst/scarlet/lite/operators.py +++ b/python/lsst/scarlet/lite/operators.py @@ -1,4 +1,6 @@ -from typing import Callable, Sequence, cast +from __future__ import annotations + +from typing import Any, Callable, Sequence, cast import numpy as np import numpy.typing as npt @@ -251,6 +253,32 @@ def __call__(self, image: np.ndarray, center: tuple[int, int]) -> np.ndarray: image[:] = result[1:-1, 1:-1] return image + def __copy__(self) -> Monotonicity: + """Create a shallow copy of the operator + + Returns + ------- + result: + A copy of the operator. + """ + new = Monotonicity(self.shape, self.dtype, self.auto_update, self.fit_radius) + return new + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Monotonicity: + """Create a deep copy of the operator + + Parameters + ---------- + memo: + The memoization dictionary for deep copies. + + Returns + ------- + result: + A copy of the operator. + """ + return self.__copy__() + def get_peak(image: np.ndarray, center: tuple[int, int], radius: int = 1) -> tuple[int, int]: """Search around a location for the maximum flux diff --git a/python/lsst/scarlet/lite/parameters.py b/python/lsst/scarlet/lite/parameters.py index 5790c4f..03a16ab 100644 --- a/python/lsst/scarlet/lite/parameters.py +++ b/python/lsst/scarlet/lite/parameters.py @@ -32,7 +32,8 @@ "DEFAULT_ADAPROX_FACTOR", ] -from typing import Callable, Sequence, cast +from copy import deepcopy +from typing import Any, Callable, Sequence, cast import numpy as np import numpy.typing as npt @@ -120,11 +121,50 @@ def dtype(self) -> npt.DTypeLike: """The numpy dtype of the array that is being fit.""" return self.x.dtype - def copy(self) -> Parameter: - """Copy this parameter, including all of the helper arrays.""" + def __copy__(self) -> Parameter: + """Create a shallow copy of this parameter. + + Returns + ------- + parameter: + A shallow copy of this parameter. + """ helpers = {k: v.copy() for k, v in self.helpers.items()} return Parameter(self.x.copy(), helpers, 0) + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Parameter: + """Create a deep copy of this parameter. + + Parameters + ---------- + memo: + A memoization dictionary used by `copy.deepcopy`. + Returns + ------- + parameter: + A deep copy of this parameter. + """ + helpers = {k: deepcopy(v, memo) for k, v in self.helpers.items()} + return Parameter(deepcopy(self.x, memo), helpers, 0) + + def copy(self, deep: bool = False) -> Parameter: + """Copy this parameter, including all of the helper arrays. + + Parameters + ---------- + deep: + If `True`, a deep copy is made. + If `False`, a shallow copy is made. + + Returns + ------- + parameter: + A copy of this parameter. + """ + if deep: + return self.__deepcopy__({}) + return self.__copy__() + def update(self, it: int, input_grad: np.ndarray, *args): """Update the parameter in one iteration. @@ -197,7 +237,7 @@ def __init__( z0: np.ndarray | None = None, ): if z0 is None: - z0 = x + z0 = x.copy() super().__init__( x, @@ -231,6 +271,44 @@ def update(self, it: int, input_grad: np.ndarray, *args): _x[:] = x self.t = t + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FistaParameter: + """Create a deep copy of this parameter. + + Parameters + ---------- + memo: + A memoization dictionary used by `copy.deepcopy`. + Returns + ------- + parameter: + A deep copy of this parameter. + """ + return FistaParameter( + deepcopy(self.x, memo), + self.step, + self.grad, + self.prox, + self.t, + deepcopy(self.helpers["z"], memo), + ) + + def __copy__(self) -> FistaParameter: + """Create a shallow copy of this parameter. + + Returns + ------- + parameter: + A shallow copy of this parameter. + """ + return FistaParameter( + self.x.copy(), + self.step, + self.grad, + self.prox, + self.t, + self.helpers["z"].copy(), + ) + # The following code block contains different update methods for # various implementations of ADAM. @@ -375,7 +453,7 @@ def __init__( step: Callable | float, grad: Callable | None = None, prox: Callable | None = None, - b1: float = 0.9, + b1: float | SingleItemArray = 0.9, b2: float = 0.999, eps: float = 1e-8, p: float = 0.25, @@ -418,6 +496,7 @@ def __init__( self.eps = eps self.p = p + self.scheme = scheme self.phi_psi = phi_psi[scheme] self.e_rel = prox_e_rel @@ -453,6 +532,58 @@ def update(self, it: int, input_grad: np.ndarray, *args): self.x = cast(Callable, self.prox)(_x) + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> AdaproxParameter: + """Create a deep copy of this parameter. + + Parameters + ---------- + memo: + A memoization dictionary used by `copy.deepcopy`. + Returns + ------- + parameter: + A deep copy of this parameter. + """ + return AdaproxParameter( + deepcopy(self.x, memo), + self.step, + self.grad, + self.prox, + self.b1, + self.b2, + self.eps, + self.p, + deepcopy(self.helpers["m"], memo), + deepcopy(self.helpers["v"], memo), + deepcopy(self.helpers["vhat"], memo), + scheme=self.scheme, + prox_e_rel=self.e_rel, + ) + + def __copy__(self) -> AdaproxParameter: + """Create a shallow copy of this parameter. + + Returns + ------- + parameter: + A shallow copy of this parameter. + """ + return AdaproxParameter( + self.x, + self.step, + self.grad, + self.prox, + self.b1, + self.b2, + self.eps, + self.p, + self.helpers["m"], + self.helpers["v"], + self.helpers["vhat"], + scheme=self.scheme, + prox_e_rel=self.e_rel, + ) + class FixedParameter(Parameter): """A parameter that is not updated""" @@ -463,6 +594,31 @@ def __init__(self, x: np.ndarray): def update(self, it: int, input_grad: np.ndarray, *args): pass + def __copy__(self) -> FixedParameter: + """Create a shallow copy of this parameter. + + Returns + ------- + parameter: + A shallow copy of this parameter. + """ + return FixedParameter(self.x) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FixedParameter: + """Create a deep copy of this parameter. + + Parameters + ---------- + memo: + A memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + parameter: + A deep copy of this parameter. + """ + return FixedParameter(deepcopy(self.x, memo)) + def relative_step( x: np.ndarray, diff --git a/python/lsst/scarlet/lite/source.py b/python/lsst/scarlet/lite/source.py index c17afcc..5acab55 100644 --- a/python/lsst/scarlet/lite/source.py +++ b/python/lsst/scarlet/lite/source.py @@ -24,7 +24,8 @@ __all__ = ["Source"] from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, Self from .bbox import Box from .component import Component @@ -42,6 +43,7 @@ class SourceBase(ABC): """ metadata: dict[str, Any] | None = None + components: list[Component] @abstractmethod def to_data(self) -> ScarletSourceBaseData: @@ -53,6 +55,69 @@ def to_data(self) -> ScarletSourceBaseData: The `ScarletSourceData` representation of this source. """ + @abstractmethod + def __getitem__(self, indices: Any) -> Self: + """Get a sub-source corresponding to the given indices. + + Parameters + ---------- + indices: Any + The indices to use to slice the source model. + + Returns + ------- + source: SourceBase + A new source that is a sub-source of this one. + + Raises + ------ + IndexError : + If the index includes a ``Box`` or spatial indices. + """ + + @abstractmethod + def __deepcopy__(self, memo: dict[int, Any]) -> Self: + """Create a deep copy of this source. + + Parameters + ---------- + memo : dict[int, Any] + A memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + source : SourceBase + A new source that is a deep copy of this one. + """ + + @abstractmethod + def __copy__(self) -> Self: + """Create a copy of this source. + + Returns + ------- + source : SourceBase + A new source that is a copy of this one. + """ + + def copy(self, deep: bool = False) -> Self: + """Create a copy of this source. + + Parameters + ---------- + deep : bool, optional + If `True`, a deep copy is made. If `False`, a shallow copy is made. + Default is `False`. + + Returns + ------- + source : Self + A new source that is a copy of this one. + """ + if deep: + return self.__deepcopy__({}) + return self.__copy__() + class Source(SourceBase): """A container for components associated with the same astrophysical object @@ -66,9 +131,14 @@ class Source(SourceBase): The components contained in the source. """ - def __init__(self, components: list[Component], metadata: dict | None = None): + def __init__( + self, + components: list[Component], + metadata: dict | None = None, + flux_weighted_image: Image | None = None, + ): self.components = components - self.flux_weighted_image: Image | None = None + self.flux_weighted_image = flux_weighted_image self.metadata = metadata @property @@ -182,3 +252,74 @@ def __str__(self): def __repr__(self): return f"Source(components={repr(self.components)})>" + + def __getitem__(self, indices: Any) -> Source: + """Get a sub-source corresponding to the given indices. + + Parameters + ---------- + indices: Any + The indices to use to slice the source model. Can be: + - A single band + - A slice with start/stop bands + - A sequence of bands + + Returns + ------- + source: Source + A new source that is a sub-source of this one. + + Raises + ------ + IndexError : + If the index includes a ``Box`` or spatial indices. + """ + flux = None if self.flux_weighted_image is None else self.flux_weighted_image[indices] + return Source( + components=[c[indices] for c in self.components], + metadata=self.metadata, + flux_weighted_image=flux, + ) + + def __deepcopy__(self, memo: dict[int, Any]) -> Source: + """Create a deep copy of this source. + + Parameters + ---------- + memo : dict[int, Any] + A memoization dictionary used by `copy.deepcopy`. + + Returns + ------- + source : SourceBase + A new source that is a deep copy of this one. + """ + # Check if already copied + if id(self) in memo: + return memo[id(self)] + + # Create placeholder and add to memo FIRST + source = Source.__new__(Source) + memo[id(self)] = source + + source.__init__( # type: ignore[misc] + components=deepcopy(self.components, memo), + metadata=deepcopy(self.metadata, memo), + flux_weighted_image=deepcopy(self.flux_weighted_image, memo), + ) + return source + + def __copy__(self) -> Source: + """Create a copy of this source. + + Returns + ------- + source : SourceBase + A new source that is a copy of this one. + """ + source = Source( + components=self.components, + metadata=self.metadata, + flux_weighted_image=self.flux_weighted_image, + ) + return source diff --git a/python/lsst/scarlet/lite/utils.py b/python/lsst/scarlet/lite/utils.py index a85fb56..6340323 100644 --- a/python/lsst/scarlet/lite/utils.py +++ b/python/lsst/scarlet/lite/utils.py @@ -19,7 +19,10 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + import sys +from typing import Any, Sequence import numpy as np import numpy.typing as npt @@ -158,6 +161,76 @@ def is_attribute_safe_to_transfer(name, value): return True +def convert_indices(sequence: Sequence, indices: Any, inclusive: bool = True) -> tuple[int, ...] | slice: + """Get either a tuple of indices or a slice object from the given sequence. + + Parameters + ---------- + sequence : Sequence + The sequence to get the indices from. This sequence should have + unique hashable elements. + + indices : Any + The indices or slice to use. Can be: + - A single element from sequence + - A slice with start/stop elements from sequence + - A sequence of elements from sequence + + inclusive : bool, optional + If True, the stop element of a slice is inclusive. + + Returns + ------- + tuple[int, ...] | slice + A tuple of indices or a slice object. + + Raises + ------ + TypeError : + If `sequence` does not support `index` and `in` operations. + IndexError : + If a single element is not found in `sequence`. + """ + # Validate that sequence has the required methods + if not hasattr(sequence, "index") or not hasattr(sequence, "__contains__"): + raise TypeError(f"'sequence' must support 'index' and 'in' operations, got {type(sequence)}") + + # Handle slice objects + if isinstance(indices, slice): + # Convert a slice of objects into a slice of array indices + try: + start = None if indices.start is None else sequence.index(indices.start) + except ValueError as e: + raise IndexError(f"Element {indices.start} not found in sequence {sequence}.") from e + try: + stop = None if indices.stop is None else sequence.index(indices.stop) + (1 if inclusive else 0) + except ValueError as e: + raise IndexError(f"Element {indices.stop} not found in sequence {sequence}.") from e + return slice(start, stop, indices.step) + + # Try to handle as a single element first + if indices in sequence: + return (sequence.index(indices),) + + # Validate that indices is iterable + if not hasattr(indices, "__iter__"): + raise IndexError(f"Element {indices} not found in sequence {sequence}.") + + # Handle sequence of indices + index_map = {value: idx for idx, value in enumerate(sequence)} + new_indices = [] + for i in indices: + try: + if i not in index_map: + raise IndexError(f"Element {i} not found in sequence {sequence}.") + except TypeError as e: + # If the + raise IndexError(f"Element {i} not found in sequence {sequence}.") from e + new_indices.append(index_map[i]) + + return tuple(new_indices) + + def continue_class(cls): """Re-open the decorated class, adding any new definitions into the original. diff --git a/tests/test_bbox.py b/tests/test_bbox.py index cc4cbb8..85856bf 100644 --- a/tests/test_bbox.py +++ b/tests/test_bbox.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from copy import deepcopy + import numpy as np from lsst.scarlet.lite import Box from utils import ScarletTestCase @@ -213,3 +215,15 @@ def test_slicing(self): self.assertBoxEqual(bbox[:3], Box((1, 2, 3), (2, 4, 6))) # check tuple index self.assertBoxEqual(bbox[(3, 1)], Box((4, 2), (8, 4))) + + def test_shallow_copy(self): + bbox = Box((3, 4, 5), (10, 20, 30)) + bbox_copy = bbox.copy() + self.assertBoxEqual(bbox, bbox_copy) + self.assertIsNot(bbox, bbox_copy) + + def test_deepcopy(self): + bbox = Box((3, 4, 5), (10, 20, 30)) + bbox_deepcopy = deepcopy(bbox) + self.assertBoxEqual(bbox, bbox_deepcopy) + self.assertIsNot(bbox, bbox_deepcopy) diff --git a/tests/test_blend.py b/tests/test_blend.py index 9604993..e27295e 100644 --- a/tests/test_blend.py +++ b/tests/test_blend.py @@ -21,45 +21,19 @@ from __future__ import annotations -from typing import Callable, cast +from typing import cast import numpy as np from lsst.scarlet.lite import Blend, Box, Image, Observation, Source -from lsst.scarlet.lite.component import Component, FactorizedComponent, default_adaprox_parameterization +from lsst.scarlet.lite.component import CubeComponent, FactorizedComponent, default_adaprox_parameterization from lsst.scarlet.lite.initialization import FactorizedInitialization from lsst.scarlet.lite.operators import Monotonicity -from lsst.scarlet.lite.parameters import Parameter from lsst.scarlet.lite.utils import integrated_circular_gaussian from numpy.testing import assert_almost_equal, assert_raises from scipy.signal import convolve as scipy_convolve from utils import ObservationData, ScarletTestCase -class DummyCubeComponent(Component): - def __init__(self, model: Image): - super().__init__(model.bands, model.bbox) - self._model = Parameter(model.data, {}, 0) - - @property - def data(self) -> np.ndarray: - return self._model.x - - def resize(self, model_box: Box) -> bool: - pass - - def update(self, it: int, input_grad: np.ndarray): - pass - - def get_model(self) -> Image: - return Image(self.data, bands=self.bands, yx0=self.bbox.origin) - - def parameterize(self, parameterization: Callable) -> None: - pass - - def to_data(self) -> DummyCubeComponent: - pass - - class TestBlend(ScarletTestCase): def setUp(self): bands = ("g", "r", "i") @@ -230,7 +204,7 @@ def test_non_factorized(self): # Remove the disk component from the first source blend.sources[0].components = blend.sources[0].components[:1] # Create a new source for the disk with a non-factorized component - component = DummyCubeComponent(Image(model, bands=self.blend.observation.bands, yx0=yx0)) + component = CubeComponent(Image(model, bands=self.blend.observation.bands, yx0=yx0), (0, 0)) blend.sources.append(Source([component])) blend.fit_spectra() @@ -254,7 +228,7 @@ def test_clipping(self): # Add an empty source zero_model = Image.from_box(Box((5, 5), (30, 0)), bands=blend.observation.bands) - component = DummyCubeComponent(zero_model) + component = CubeComponent(zero_model, (0, 0)) blend.sources.append(Source([component])) blend.fit_spectra(clip=True) @@ -262,3 +236,107 @@ def test_clipping(self): self.assertEqual(len(blend.components), 5) self.assertEqual(len(blend.sources), 5) self.assertImageAlmostEqual(blend.get_model(), self.data.images) + + def test_shallow_copy(self): + blend = self.blend + blend.metadata = {"test": "value"} + blend_copy = blend.copy() + + self.assertIsNot(blend_copy, blend) + self.assertEqual(len(blend_copy.sources), len(blend.sources)) + for source_copy, source in zip(blend_copy.sources, blend.sources): + self.assertSourceEqual(source_copy, source) + + self.assertObservationEqual(blend_copy.observation, blend.observation) + + self.assertDictEqual(blend_copy.metadata, blend.metadata) + + def test_deepcopy(self): + blend = self.blend + blend.metadata = {"test": "value"} + blend_copy = blend.copy(deep=True) + + self.assertIsNot(blend_copy, blend) + self.assertEqual(len(blend_copy.sources), len(blend.sources)) + for source_copy, source in zip(blend_copy.sources, blend.sources): + self.assertSourceEqual(source_copy, source) + + with self.assertRaises(AssertionError): + source_copy.components[0]._spectrum.x += 1 + self.assertSourceEqual(source_copy, source) + + self.assertObservationEqual(blend_copy.observation, blend.observation) + self.assertDictEqual(blend_copy.metadata, blend.metadata) + blend_copy.metadata["test"] = "new_value" + with self.assertRaises(AssertionError): + self.assertDictEqual(blend_copy.metadata, blend.metadata) + + def test_slice(self): + blend = self.blend + blend.metadata = {"test": "value"} + blend_sliced = blend["g":"r"] + self.assertEqual(len(blend.sources), len(blend_sliced.sources)) + + for source_sliced, source in zip(blend_sliced.sources, blend.sources): + self.assertSourceEqual(source_sliced, source["g":"r"]) + + self.assertObservationEqual(blend_sliced.observation, blend.observation["g":"r"]) + self.assertDictEqual(blend_sliced.metadata, blend.metadata) + + def test_reorder(self): + blend = self.blend + blend.metadata = {"test": "value"} + indices = ("i", "g", "r") + blend_reordered = blend[indices] + self.assertEqual(len(blend.sources), len(blend_reordered.sources)) + + for source_reordered, source in zip(blend_reordered.sources, blend.sources): + self.assertSourceEqual(source_reordered, source[indices]) + + self.assertObservationEqual(blend_reordered.observation, blend.observation[indices]) + self.assertDictEqual(blend_reordered.metadata, blend.metadata) + + def test_subset(self): + blend = self.blend + blend.metadata = {"test": "value"} + blend_subset = blend[("r",)] + self.assertEqual(len(blend.sources), len(blend_subset.sources)) + + for source_subset, source in zip(blend_subset.sources, blend.sources): + self.assertSourceEqual(source_subset, source["r"]) + + self.assertObservationEqual(blend_subset.observation, blend.observation["r"]) + self.assertDictEqual(blend_subset.metadata, blend.metadata) + + def test_indexing_errors(self): + blend = self.blend + + with self.assertRaises(IndexError): + blend["x"] + + with self.assertRaises(IndexError): + blend[("r", "x")] + + with self.assertRaises(IndexError): + blend["r":"x"] + + with self.assertRaises(IndexError): + blend["x":"i"] + + with self.assertRaises(IndexError): + blend["g", "x", "i"] + + with self.assertRaises(IndexError): + blend[Box((0, 0), (10, 10))] + + with self.assertRaises(IndexError): + blend[:, 10:20, 10:20] + + with self.assertRaises(IndexError): + blend[1:] + + with self.assertRaises(IndexError): + blend[1] + + with self.assertRaises(IndexError): + blend[0, 1] diff --git a/tests/test_component.py b/tests/test_component.py index e054675..85871cc 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -21,17 +21,20 @@ from __future__ import annotations -from typing import Callable +from abc import ABC +from typing import Any, Callable import numpy as np from lsst.scarlet.lite import Box, Image, Parameter from lsst.scarlet.lite.component import ( Component, + CubeComponent, FactorizedComponent, default_adaprox_parameterization, default_fista_parameterization, ) from lsst.scarlet.lite.operators import Monotonicity +from lsst.scarlet.lite.utils import integrated_circular_gaussian from numpy.testing import assert_almost_equal, assert_array_equal from utils import ScarletTestCase @@ -52,8 +55,92 @@ def parameterize(self, parameterization: Callable) -> None: def to_data(self) -> DummyComponent: pass + def __getitem__(self, indices: Any) -> DummyComponent: + pass + + def __copy__(self) -> DummyComponent: + pass + + def __deepcopy__(self, memo: dict[int, Any]) -> DummyComponent: + pass + + +class _ComponentTestBase(ABC): + def test_slice(self): + component = self.component + component_sliced = component["g":"r"] + self.assertTupleEqual(component_sliced.bands, ("g", "r")) + np.testing.assert_array_equal(component_sliced.get_model(), component.get_model().data[0:2]) + + def test_reorder(self): + component = self.component + indices = ("i", "g", "r") + component_reordered = component["i", "g", "r"] + self.assertTupleEqual(component_reordered.bands, indices) + np.testing.assert_array_equal( + component_reordered.get_model(), + component.get_model().data[(2, 0, 1),], + ) + + component_reordered = component["igr"] + self.assertTupleEqual(component_reordered.bands, indices) + np.testing.assert_array_equal( + component_reordered.get_model(), + component.get_model().data[(2, 0, 1),], + ) + + def test_subset(self): + component = self.component + indices = ("r",) + component_subset = component["r"] + self.assertTupleEqual(component_subset.bands, indices) + np.testing.assert_array_equal( + component_subset.get_model(), + component.get_model().data[1:2,], + ) + + component = self.component.copy(deep=True) + component._bands = ("ab", "cd", "ef") + indices = "ab" + component_reordered = component["ab"] + self.assertTupleEqual(component_reordered.bands, (indices,)) + np.testing.assert_array_equal( + component_reordered.get_model(), + component.get_model().data[0:1,], + ) + + def test_indexing_errors(self): + component = self.component + print("bands", component.bands) + with self.assertRaises(IndexError): + component["z"] + + with self.assertRaises(IndexError): + component["r":"z"] + + with self.assertRaises(IndexError): + component["z":"i"] + + with self.assertRaises(IndexError): + component["g", "z", "i"] + + with self.assertRaises(IndexError): + component[Box((0, 0), (10, 10))] + + with self.assertRaises(IndexError): + component[:, 10:20, 10:20] + + with self.assertRaises(IndexError): + component[1:] + + with self.assertRaises(IndexError): + component[1] -class TestFactorizedComponent(ScarletTestCase): + with self.assertRaises(IndexError): + component[0, 1] + + +class TestFactorizedComponent(_ComponentTestBase, ScarletTestCase): def setUp(self) -> None: spectrum = np.arange(3).astype(np.float32) morph = np.arange(20).reshape(4, 5).astype(np.float32) @@ -246,3 +333,82 @@ def test_parameterization(self): with self.assertRaises(NotImplementedError): default_adaprox_parameterization(DummyComponent(*params)) + + def test_shallow_copy(self): + component = self.component + component.monotonicity = Monotonicity((11, 11), fit_radius=0) + + component_copy = component.copy() + + self.assertIsNot(component, component_copy) + np.testing.assert_array_equal(component._spectrum.x, component_copy._spectrum.x) + np.testing.assert_array_equal(component._morph.x, component_copy._morph.x) + self.assertIs(component.bbox, component_copy.bbox) + self.assertIs(component.peak, component_copy.peak) + self.assertIs(component.bg_thresh, component_copy.bg_thresh) + self.assertIs(component.monotonicity, component_copy.monotonicity) + + def test_deep_copy(self): + component = self.component + component.monotonicity = Monotonicity((11, 11), fit_radius=0) + component_deepcopy = component.copy(deep=True) + + self.assertIsNot(component, component_deepcopy) + + np.testing.assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x) + component_deepcopy._spectrum.x += 1 + with self.assertRaises(AssertionError): + np.testing.assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x) + + np.testing.assert_array_equal(component._morph.x, component_deepcopy._morph.x) + component_deepcopy._morph.x += 1 + with self.assertRaises(AssertionError): + np.testing.assert_array_equal(component._morph.x, component_deepcopy._morph.x) + + self.assertIsNot(component.bbox, component_deepcopy.bbox) + self.assertBoxEqual(component.bbox, component_deepcopy.bbox) + + self.assertTupleEqual(component.peak, component_deepcopy.peak) + self.assertEqual(component.bg_thresh, component_deepcopy.bg_thresh) + self.assertIsNot(component.monotonicity, component_deepcopy.monotonicity) + + +class TestCubeComponent(_ComponentTestBase, ScarletTestCase): + def setUp(self) -> None: + super().setUp() + self.bands = tuple("gri") + peak = (27, 32) + bbox = Box((15, 15), (20, 25)) + morph = integrated_circular_gaussian(sigma=0.8).astype(np.float32) + spectrum = np.arange(3, dtype=np.float32) + model = morph[None, :, :] * spectrum[:, None, None] + model_image = Image(model, yx0=bbox.origin, bands=self.bands) + self.component = CubeComponent(model=model_image, peak=peak) + + def test_constructor(self): + component = self.component + self.assertIsInstance(component._model, Image) + np.testing.assert_array_equal(component._model.data, self.component._model.data) + self.assertTupleEqual(component.bands, self.bands) + self.assertBoxEqual(component.bbox, Box((15, 15), (20, 25))) + self.assertTupleEqual(component.peak, (27, 32)) + + def test_shallow_copy(self): + component = self.component + component_copy = component.copy() + + self.assertIsNot(component_copy, component) + self.assertTupleEqual(component_copy.peak, component.peak) + self.assertImageEqual(component_copy._model, component._model) + + def test_deep_copy(self): + component = self.component + component_copy = component.copy(deep=True) + + self.assertIsNot(component, component_copy) + + self.assertTupleEqual(component_copy.peak, component.peak) + self.assertImageEqual(component_copy._model, component._model) + with self.assertRaises(AssertionError): + component_copy._model._data -= 1 + self.assertImageEqual(component_copy._model, component._model) diff --git a/tests/test_io.py b/tests/test_io.py index 9e8cb4a..9b2b8ed 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -24,8 +24,8 @@ import numpy as np from lsst.scarlet.lite import Blend, Image, Observation, io +from lsst.scarlet.lite.component import CubeComponent from lsst.scarlet.lite.initialization import FactorizedInitialization -from lsst.scarlet.lite.io import ComponentCube from lsst.scarlet.lite.operators import Monotonicity from lsst.scarlet.lite.utils import integrated_circular_gaussian from numpy.testing import assert_almost_equal @@ -107,7 +107,7 @@ def test_cube_component(self): blend.sources[i].metadata = {"id": f"peak-{i}"} component = blend.sources[-1].components[-1] # Replace one of the components with a Free-Form component. - blend.sources[-1].components[-1] = ComponentCube( + blend.sources[-1].components[-1] = CubeComponent( model=component.get_model(), peak=component.peak, ) diff --git a/tests/test_measure.py b/tests/test_measure.py index 1a6eae3..9082cf6 100644 --- a/tests/test_measure.py +++ b/tests/test_measure.py @@ -22,8 +22,8 @@ import os import numpy as np -from lsst.scarlet.lite import Blend, Image, Observation, Source, io -from lsst.scarlet.lite.component import default_adaprox_parameterization +from lsst.scarlet.lite import Blend, Image, Observation, Source +from lsst.scarlet.lite.component import CubeComponent, default_adaprox_parameterization from lsst.scarlet.lite.initialization import FactorizedInitialization from lsst.scarlet.lite.measure import calculate_snr from lsst.scarlet.lite.operators import Monotonicity @@ -84,7 +84,7 @@ def test_conserve_flux(self): blend.sources.append( Source( [ - io.ComponentCube( + CubeComponent( model=Image( np.ones(observation.shape, dtype=observation.dtype), observation.bands, diff --git a/tests/test_observation.py b/tests/test_observation.py index 65dc1d1..95bbe30 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -285,3 +285,22 @@ def test_slicing(self): observation.noise_rms[1:4], ) self.assertBoxEqual(sliced_observation.bbox, new_box) + + def test_shallow_copy(self): + observation_copy = self.observation.copy() + self.assertObservationEqual(observation_copy, self.observation) + + def test_deep_copy(self): + observation_copy = self.observation.copy(deep=True) + self.assertObservationEqual(observation_copy, self.observation) + + # Modify the copy and check that the original is unchanged + observation_copy.images._data += 1 + with self.assertRaises(AssertionError): + self.assertImageEqual(observation_copy.images, self.observation.images) + observation_copy.variance._data += 1 + with self.assertRaises(AssertionError): + self.assertImageEqual(observation_copy.variance, self.observation.variance) + observation_copy.weights._data += 1 + with self.assertRaises(AssertionError): + self.assertImageEqual(observation_copy.weights, self.observation.weights) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 043398e..cb88aa9 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -140,3 +140,82 @@ def test_fixed_parameter(self): param = FixedParameter(x) param.update(10, np.arange(10) * 2) assert_array_equal(param.x, x) + + def test_shallow_copy(self): + x = np.arange(10, dtype=float) + + # FistaParameter + param = FistaParameter(x, 0.1) + param_copy = param.copy() + self.assertIsInstance(param_copy, FistaParameter) + + assert_array_equal(param.x, param_copy.x) + assert_array_equal(param.helpers["z"], param_copy.helpers["z"]) + + # AdaproxParameter + param = AdaproxParameter(x, 0.1) + param_copy = param.copy() + self.assertIsInstance(param_copy, AdaproxParameter) + + assert_array_equal(param.x, param_copy.x) + assert_array_equal(param.helpers["m"], param_copy.helpers["m"]) + assert_array_equal(param.helpers["v"], param_copy.helpers["v"]) + assert_array_equal(param.helpers["vhat"], param_copy.helpers["vhat"]) + + # FixedParameter + param = FixedParameter(x) + param_copy = param.copy() + self.assertIsInstance(param_copy, FixedParameter) + assert_array_equal(param.x, param_copy.x) + + def test_deep_copy(self): + x = np.arange(10, dtype=float) + + # FistaParameter + param = FistaParameter(x, 0.1) + param_deepcopy = param.copy(deep=True) + self.assertIsInstance(param_deepcopy, FistaParameter) + + assert_array_equal(param.x, param_deepcopy.x) + param_deepcopy.x += 1 + with self.assertRaises(AssertionError): + assert_array_equal(param.x, param_deepcopy.x) + + assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"]) + param_deepcopy.helpers["z"] += 1 + with self.assertRaises(AssertionError): + assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"]) + + # AdaproxParameter + param = AdaproxParameter(x, 0.1) + param_deepcopy = param.copy(deep=True) + self.assertIsInstance(param_deepcopy, AdaproxParameter) + + assert_array_equal(param.x, param_deepcopy.x) + param_deepcopy.x += 1 + with self.assertRaises(AssertionError): + assert_array_equal(param.x, param_deepcopy.x) + + assert_array_equal(param.helpers["m"], param_deepcopy.helpers["m"]) + param_deepcopy.helpers["m"] = -1 + with self.assertRaises(AssertionError): + assert_array_equal(param.helpers["m"], param_deepcopy.helpers["m"]) + + assert_array_equal(param.helpers["v"], param_deepcopy.helpers["v"]) + param_deepcopy.helpers["v"] = -1 + with self.assertRaises(AssertionError): + assert_array_equal(param.helpers["v"], param_deepcopy.helpers["v"]) + + assert_array_equal(param.helpers["vhat"], param_deepcopy.helpers["vhat"]) + param_deepcopy.helpers["vhat"] = -1 + with self.assertRaises(AssertionError): + assert_array_equal(param.helpers["vhat"], param_deepcopy.helpers["vhat"]) + + # FixedParameter + param = FixedParameter(x) + param_deepcopy = param.copy(deep=True) + self.assertIsInstance(param_deepcopy, FixedParameter) + assert_array_equal(param.x, param_deepcopy.x) + param_deepcopy.x += 1 + with self.assertRaises(AssertionError): + assert_array_equal(param.x, param_deepcopy.x) diff --git a/tests/test_source.py b/tests/test_source.py index 437c6c4..591a4bc 100644 --- a/tests/test_source.py +++ b/tests/test_source.py @@ -27,8 +27,34 @@ class TestSource(ScarletTestCase): - def test_constructor(self): - # Test empty source + def setUp(self) -> None: + super().setUp() + self.bands = tuple("grizy") + self.center = (27, 32) + self.morph1 = integrated_circular_gaussian(sigma=0.8).astype(np.float32) + self.spectrum1 = np.arange(5).astype(np.float32) + self.component_box1 = Box((15, 15), (20, 25)) + self.morph2 = integrated_circular_gaussian(sigma=2.1).astype(np.float32) + self.spectrum2 = np.arange(5)[::-1].astype(np.float32) + self.component_box2 = Box((15, 15), (10, 35)) + self.components = [ + FactorizedComponent( + self.bands, + self.spectrum1, + self.morph1, + self.component_box1, + self.center, + ), + FactorizedComponent( + self.bands, + self.spectrum2, + self.morph2, + self.component_box2, + self.center, + ), + ] + + def test_empty_constructor(self): source = Source([]) self.assertEqual(source.n_components, 0) @@ -38,74 +64,41 @@ def test_constructor(self): self.assertBoxEqual(source.bbox, Box((0, 0))) self.assertTupleEqual(source.bands, ()) - # Test a source with a single component - bands = tuple("grizy") - center = (27, 32) - morph1 = integrated_circular_gaussian(sigma=0.8).astype(np.float32) - spectrum1 = np.arange(5).astype(np.float32) - component_box1 = Box((15, 15), (20, 25)) - components = [ - FactorizedComponent( - bands, - spectrum1, - morph1, - component_box1, - center, - ), - ] - source = Source(components) + def test_single_component_constructor(self): + source = Source(self.components[:1]) self.assertEqual(source.n_components, 1) - self.assertTupleEqual(source.center, center) + self.assertTupleEqual(source.center, self.center) self.assertTupleEqual(source.source_center, (7, 7)) self.assertFalse(source.is_null) - self.assertBoxEqual(source.bbox, component_box1) - self.assertTupleEqual(source.bands, bands) + self.assertBoxEqual(source.bbox, self.component_box1) + self.assertTupleEqual(source.bands, self.bands) self.assertImageEqual( source.get_model(), Image( - spectrum1[:, None, None] * morph1[None, :, :], - yx0=component_box1.origin, - bands=bands, + self.spectrum1[:, None, None] * self.morph1[None, :, :], + yx0=self.component_box1.origin, + bands=self.bands, ), ) self.assertIsNone(source.get_model(True)) self.assertEqual(source.get_model().dtype, np.float32) + def test_multiple_component_constructor(self): # Test a source with multiple components - morph2 = integrated_circular_gaussian(sigma=2.1).astype(np.float32) - spectrum2 = np.arange(5)[::-1].astype(np.float32) - component_box2 = Box((15, 15), (10, 35)) - - components = [ - FactorizedComponent( - bands, - spectrum1, - morph1, - component_box1, - center, - ), - FactorizedComponent( - bands, - spectrum2, - morph2, - component_box2, - center, - ), - ] - source = Source(components) + source = Source(self.components) self.assertEqual(source.n_components, 2) - self.assertTupleEqual(source.center, center) + self.assertTupleEqual(source.center, self.center) self.assertTupleEqual(source.source_center, (17, 7)) self.assertFalse(source.is_null) self.assertBoxEqual(source.bbox, Box((25, 25), (10, 25))) - self.assertTupleEqual(source.bands, bands) + self.assertTupleEqual(source.bands, self.bands) self.assertEqual(str(source), "Source<2>") self.assertEqual(source.get_model().dtype, np.float32) model = np.zeros((5, 25, 25), dtype=np.float32) - model[:, 10:25, :15] = spectrum1[:, None, None] * morph1[None, :, :] - model[:, :15, 10:25] += spectrum2[:, None, None] * morph2[None, :, :] - model = Image(model, yx0=(10, 25), bands=tuple("grizy")) + model[:, 10:25, :15] = self.spectrum1[:, None, None] * self.morph1[None, :, :] + model[:, :15, 10:25] += self.spectrum2[:, None, None] * self.morph2[None, :, :] + model = Image(model, yx0=(10, 25), bands=self.bands) self.assertImageEqual( source.get_model(), @@ -116,3 +109,102 @@ def test_constructor(self): source = Source([]) result = source.get_model() self.assertEqual(result, 0) + + def test_shallow_copy(self): + source = Source(self.components) + source_copy = source.copy() + + self.assertIsNot(source, source_copy) + self.assertEqual(source.n_components, 2) + self.assertEqual(source.n_components, source_copy.n_components) + self.assertFactorizedComponentEqual(source.components[0], source_copy.components[0]) + self.assertFactorizedComponentEqual(source.components[1], source_copy.components[1]) + self.assertIs(source.flux_weighted_image, source_copy.flux_weighted_image) + self.assertIs(source.metadata, source_copy.metadata) + + def test_deepcopy(self): + source = Source(self.components) + source_deepcopy = source.copy(deep=True) + + self.assertIsNot(source, source_deepcopy) + self.assertEqual(source.n_components, source_deepcopy.n_components) + for comp, comp_deepcopy in zip(source.components, source_deepcopy.components): + self.assertIsNot(comp, comp_deepcopy) + self.assertFactorizedComponentEqual(comp, comp_deepcopy) + comp_deepcopy._spectrum.x += 1 + with self.assertRaises(AssertionError): + np.testing.assert_array_equal(comp._spectrum.x, comp_deepcopy._spectrum.x) + comp_deepcopy._morph.x += 1 + with self.assertRaises(AssertionError): + np.testing.assert_array_equal(comp._morph.x, comp_deepcopy._morph.x) + + def test_slice(self): + source = Source(self.components) + source_sliced = source["g":"r"] + self.assertTupleEqual(source_sliced.bands, ("g", "r")) + self.assertEqual(source.n_components, source_sliced.n_components) + + for comp, comp_sliced in zip(source.components, source_sliced.components): + self.assertFactorizedComponentEqual(comp["g":"r"], comp_sliced) + + def test_reorder(self): + source = Source(self.components) + indices = ("i", "g", "r") + source_reordered = source[indices] + self.assertTupleEqual(source_reordered.bands, indices) + self.assertEqual(source.n_components, source_reordered.n_components) + for comp, comp_reordered in zip(source.components, source_reordered.components): + self.assertFactorizedComponentEqual(comp[indices], comp_reordered) + + source_reordered = source["igr"] + self.assertTupleEqual(source_reordered.bands, indices) + self.assertEqual(source.n_components, source_reordered.n_components) + for comp, comp_reordered in zip(source.components, source_reordered.components): + self.assertFactorizedComponentEqual(comp["igr"], comp_reordered) + + def test_subset(self): + source = Source(self.components) + source_subset = source[("r",)] + self.assertTupleEqual(source_subset.bands, ("r",)) + self.assertEqual(source.n_components, source_subset.n_components) + for comp, comp_subset in zip(source.components, source_subset.components): + self.assertFactorizedComponentEqual(comp["r"], comp_subset) + + source = source.copy(deep=True) + for comp in source.components: + comp._bands = ("ab", "cd", "ef") + source_subset = source["ab"] + self.assertTupleEqual(source_subset.bands, ("ab",)) + self.assertEqual(source.n_components, source_subset.n_components) + for comp, comp_subset in zip(source.components, source_subset.components): + self.assertFactorizedComponentEqual(comp["ab"], comp_subset) + + def test_indexing_errors(self): + source = Source(self.components) + + with self.assertRaises(IndexError): + source["x"] + + with self.assertRaises(IndexError): + source["r":"x"] + + with self.assertRaises(IndexError): + source["x":"i"] + + with self.assertRaises(IndexError): + source["g", "x", "i"] + + with self.assertRaises(IndexError): + source[Box((0, 0), (10, 10))] + + with self.assertRaises(IndexError): + source[:, 10:20, 10:20] + + with self.assertRaises(IndexError): + source[1:] + + with self.assertRaises(IndexError): + source[1] + + with self.assertRaises(IndexError): + source[0, 1] diff --git a/tests/utils.py b/tests/utils.py index 9ef72e3..c44e67b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,13 +21,15 @@ import sys import traceback -from typing import Sequence +from typing import Sequence, cast from unittest import TestCase import numpy as np from lsst.scarlet.lite.bbox import Box +from lsst.scarlet.lite.component import FactorizedComponent from lsst.scarlet.lite.fft import match_kernel from lsst.scarlet.lite.image import Image +from lsst.scarlet.lite.source import Source from lsst.scarlet.lite.utils import integrated_circular_gaussian from numpy.testing import assert_almost_equal, assert_array_equal from numpy.typing import DTypeLike @@ -200,3 +202,38 @@ def assertImageAlmostEqual(self, image: Image, truth: Image, decimal: int = 7): def assertImageEqual(self, image: Image, truth: Image): # noqa: N802 self.assertImageAlmostEqual(image, truth) assert_array_equal(image.data, truth.data) + + def assertFactorizedComponentEqual( # noqa: N802 + self, + component: FactorizedComponent, + truth: FactorizedComponent, + ): + self.assertTupleEqual(component.bands, truth.bands) + self.assertTupleEqual(component.peak, truth.peak) + np.testing.assert_array_equal(component._spectrum.x, truth._spectrum.x) + np.testing.assert_array_equal(component._morph.x, truth._morph.x) + self.assertBoxEqual(component.bbox, truth.bbox) + self.assertEqual(component.bg_rms, truth.bg_rms) + self.assertEqual(component.bg_thresh, truth.bg_thresh) + self.assertEqual(component.floor, truth.floor) + self.assertEqual(component.padding, truth.padding) + self.assertEqual(component.is_symmetric, truth.is_symmetric) + + def assertSourceEqual(self, source: Source, truth: Source): # noqa: N802 + self.assertEqual(source.n_components, truth.n_components) + self.assertBoxEqual(source.bbox, truth.bbox) + self.assertTupleEqual(source.bands, truth.bands) + for comp, comp_truth in zip(source.components, truth.components): + self.assertFactorizedComponentEqual( + cast(FactorizedComponent, comp), + cast(FactorizedComponent, comp_truth), + ) + + def assertObservationEqual(self, obs: ObservationData, truth: ObservationData): # noqa: N802 + self.assertImageEqual(obs.images, truth.images) + self.assertImageEqual(obs.variance, truth.variance) + self.assertImageEqual(obs.weights, truth.weights) + assert_array_equal(obs.psfs, truth.psfs) + assert_array_equal(obs.model_psf, truth.model_psf) + assert_array_equal(obs.noise_rms, truth.noise_rms) + self.assertBoxEqual(obs.bbox, truth.bbox)