Skip to content

Commit

Permalink
Stamp size again (#505)
Browse files Browse the repository at this point in the history
* change it so stamp size is used as a mandatory argument of sampling functions

* stamp size is automatically propagated from sampling functions

* propagate changes to stamp size
  • Loading branch information
ismael-mendoza authored Jun 26, 2024
1 parent 3fe5f6d commit 04ac79b
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 50 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ sampling_function = btk.sampling_functions.DefaultSampling(
# setup generator to create batches of blends
batch_size = 100
draw_generator = btk.draw_blends.CatsimGenerator(
catalog, sampling_function, survey, batch_size, stamp_size
catalog, sampling_function, survey, batch_size
)

# get batch of blends
Expand Down
7 changes: 1 addition & 6 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def __init__(
sampling_function: SamplingFunction,
surveys: Union[List[Survey], Survey],
batch_size: int = 8,
stamp_size: float = 24.0,
njobs: int = 1,
verbose: bool = False,
use_bar: bool = False,
Expand All @@ -165,7 +164,6 @@ def __init__(
surveys: List of BTK Survey objects or
single BTK Survey object.
batch_size: Number of blends generated per batch
stamp_size: Size of the stamps, in arcseconds
njobs: Number of njobs to use; defines the number of minibatches
verbose: Indicates whether additionnal information should be printed
use_bar: Whether to use progress bar (default: False)
Expand All @@ -187,7 +185,7 @@ def __init__(
self.max_number = self.blend_generator.max_number
self.apply_shear = apply_shear
self.augment_data = augment_data
self.stamp_size = stamp_size
self.stamp_size = sampling_function.stamp_size
self.use_bar = use_bar
self._set_surveys(surveys)

Expand Down Expand Up @@ -523,7 +521,6 @@ def __init__(
sampling_function: SamplingFunction,
surveys: List[Survey],
batch_size: int = 8,
stamp_size: float = 24.0,
njobs: int = 1,
verbose: bool = False,
add_noise: str = "all",
Expand All @@ -541,7 +538,6 @@ def __init__(
sampling_function: See parent class.
surveys: See parent class.
batch_size: See parent class.
stamp_size: See parent class.
njobs: See parent class.
verbose: See parent class.
add_noise: See parent class.
Expand All @@ -563,7 +559,6 @@ def __init__(
sampling_function,
surveys,
batch_size,
stamp_size,
njobs,
verbose,
use_bar,
Expand Down
68 changes: 40 additions & 28 deletions btk/sampling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@ class SamplingFunction(ABC):
galaxies chosen for the blend.
"""

def __init__(self, max_number: int, min_number: int = 1, seed=DEFAULT_SEED):
def __init__(self, stamp_size: int, max_number: int, min_number: int = 1, seed=DEFAULT_SEED):
"""Initializes the SamplingFunction.
Args:
stamp_size: The size of the stamp in arcseconds.
max_number: maximum number of catalog entries returned from sample.
min_number: minimum number of catalog entries returned from sample. (Default: 1)
seed: Seed to initialize randomness for reproducibility. (Default: btk.DEFAULT_SEED)
"""
self.stamp_size = stamp_size
self.min_number = min_number
self.max_number = max_number

Expand All @@ -93,9 +95,9 @@ class DefaultSampling(SamplingFunction):

def __init__(
self,
stamp_size: float = 24.0,
max_number: int = 2,
min_number: int = 1,
stamp_size: float = 24.0,
max_shift: Optional[float] = None,
seed: int = DEFAULT_SEED,
max_mag: float = 25.3,
Expand All @@ -105,18 +107,19 @@ def __init__(
"""Initializes default sampling function.
Args:
stamp_size: Size of the desired stamp.
max_number: Defined in parent class
min_number: Defined in parent class
stamp_size: Size of the desired stamp.
max_shift: Magnitude of maximum value of shift. If None then it
is set as one-tenth the stamp size. (in arcseconds)
seed: Seed to initialize randomness for reproducibility.
min_mag: Minimum magnitude allowed in samples
max_mag: Maximum magnitude allowed in samples.
mag_name: Name of the magnitude column in the catalog.
"""
super().__init__(max_number=max_number, min_number=min_number, seed=seed)
self.stamp_size = stamp_size
super().__init__(
stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed
)
self.max_shift = max_shift if max_shift is not None else self.stamp_size / 10.0
self.min_mag, self.max_mag = min_mag, max_mag
self.mag_name = mag_name
Expand Down Expand Up @@ -165,10 +168,10 @@ class DensitySampling(SamplingFunction):

def __init__(
self,
stamp_size: float = 24.0,
max_number: int = 40,
min_number: int = 0,
density: float = 185,
stamp_size: float = 24.0,
max_shift: Optional[float] = None,
seed: int = DEFAULT_SEED,
max_mag: float = 27.3,
Expand All @@ -178,20 +181,21 @@ def __init__(
"""Initializes default sampling function.
Args:
stamp_size: Size of the desired stamp (in arcseconds)
max_number: Defined in parent class
min_number: Defined in parent class
density: Density of galaxies, default corresponds to 27.3 i-band magnitude
cut in CATSIM catalog. (in counts / sq. arcmin)
stamp_size: Size of the desired stamp (in arcseconds)
max_shift: Magnitude of maximum value of shift. If None, then centroids can fall
anywhere within the image. (in arcseconds)
seed: Seed to initialize randomness for reproducibility.
min_mag: Minimum magnitude allowed in samples
max_mag: Maximum magnitude allowed in samples.
mag_name: Name of the magnitude column in the catalog.
"""
super().__init__(max_number=max_number, min_number=min_number, seed=seed)
self.stamp_size = stamp_size
super().__init__(
stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed
)
self.min_mag, self.max_mag = min_mag, max_mag
self.mag_name = mag_name
self.max_shift = max_shift if max_shift else self.stamp_size / 2
Expand Down Expand Up @@ -249,23 +253,24 @@ class BasicSampling(SamplingFunction):

def __init__(
self,
stamp_size: float = 24.0,
max_number: int = 4,
min_number: int = 1,
stamp_size: float = 24.0,
mag_name: str = "i_ab",
seed: int = DEFAULT_SEED,
):
"""Initializes the basic sampling function.
Args:
stamp_size: Size of the desired stamp.
max_number: Defined in parent class.
min_number: Defined in parent class.
stamp_size: Size of the desired stamp.
seed: Seed to initialize randomness for reproducibility.
mag_name: Name of the magnitude column in the catalog for cuts.
"""
super().__init__(max_number=max_number, min_number=min_number, seed=seed)
self.stamp_size = stamp_size
super().__init__(
stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed
)
self.mag_name = mag_name

if min_number < 1:
Expand Down Expand Up @@ -328,24 +333,33 @@ class DefaultSamplingShear(DefaultSampling):

def __init__(
self,
stamp_size: float = 24.0,
max_number: int = 2,
min_number: int = 1,
stamp_size: float = 24.0,
max_shift: Optional[float] = None,
seed=DEFAULT_SEED,
max_mag: float = 25.3,
min_mag: float = -np.inf,
mag_name: str = "i_ab",
shear: Tuple[float, float] = (0.0, 0.0),
):
"""Initializes default sampling function with shear.
Args:
stamp_size: Defined in parent class.
max_number: Defined in parent class.
min_number: Defined in parent class.
stamp_size: Defined in parent class.
max_shift: Defined in parent class.
seed: Defined in parent class.
max_mag: Defined in parent class.
min_mag: Defined in parent class.
mag_name: Defined in parent class.
shear: Constant (g1,g2) shear to apply to every galaxy.
"""
super().__init__(max_number, min_number, stamp_size, max_shift, seed)
super().__init__(
stamp_size, max_number, min_number, max_shift, seed, max_mag, min_mag, mag_name
)
self.shear = shear

def __call__(self, table: Table, **kwargs) -> Table:
Expand Down Expand Up @@ -386,7 +400,7 @@ def __init__(
bright_cut: Magnitude cut for bright galaxy. (Default: 25.3)
dim_cut: Magnitude cut for dim galaxy. (Default: 28.0)
"""
super().__init__(2, 1, seed)
super().__init__(stamp_size, 2, 1, seed)
self.stamp_size = stamp_size
self.max_shift = max_shift if max_shift is not None else self.stamp_size / 10.0
self.mag_name = mag_name
Expand Down Expand Up @@ -427,8 +441,8 @@ class RandomSquareSampling(SamplingFunction):

def __init__(
self,
max_number: int = 2,
stamp_size: float = 24.0,
max_number: int = 2,
seed: int = DEFAULT_SEED,
max_mag: float = 25.3,
min_mag: float = -np.inf,
Expand All @@ -437,16 +451,14 @@ def __init__(
"""Initializes the RandomSquareSampling sampling function.
Args:
max_number: Defined in parent class
stamp_size: Size of the desired stamp (arcsec).
max_number: Defined in parent class
seed: Seed to initialize randomness for reproducibility.
min_mag: Minimum magnitude allowed in samples
max_mag: Maximum magnitude allowed in samples.
mag_name: Name of the magnitude column in the catalog.
"""
super().__init__(max_number=max_number, min_number=0, seed=seed)
self.stamp_size = stamp_size
self.max_number = max_number
super().__init__(stamp_size=stamp_size, max_number=max_number, min_number=0, seed=seed)
self.max_mag = max_mag
self.min_mag = min_mag
self.mag_name = mag_name
Expand Down Expand Up @@ -527,10 +539,10 @@ class FriendsOfFriendsSampling(SamplingFunction):

def __init__(
self,
stamp_size: float = 24.0,
max_number: int = 10,
min_number: int = 2,
link_distance: int = 2.5,
stamp_size: float = 24.0,
seed: int = DEFAULT_SEED,
min_mag: float = -np.inf,
max_mag: float = 25.3,
Expand All @@ -539,20 +551,20 @@ def __init__(
"""Initializes the FriendsOfFriendsSampling sampling function.
Args:
max_number: Defined in parent class
min_number: Defined in parent class
stamp_size: Defined in parent class.
max_number: Defined in parent class.
min_number: Defined in parent class.
link_distance: Minimum linkage distance to form a group (arcsec).
stamp_size: Size of the desired stamp (arcsec).
seed: Seed to initialize randomness for reproducibility.
min_mag: Minimum magnitude allowed in samples
max_mag: Maximum magnitude allowed in samples.
mag_name: Name of the magnitude column in the catalog.
"""
super().__init__(max_number=max_number, min_number=min_number, seed=seed)
self.stamp_size = stamp_size
super().__init__(
stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed
)
self.link_distance = link_distance
self.max_number = max_number
self.min_number = min_number
self.max_mag = max_mag
self.min_mag = min_mag
self.mag_name = mag_name
Expand Down
2 changes: 0 additions & 2 deletions notebooks/00-quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@
" sampling_function,\n",
" LSST,\n",
" batch_size=batch_size,\n",
" stamp_size=stamp_size,\n",
" njobs=1,\n",
" add_noise=\"all\",\n",
" seed=seed, # use same seed here\n",
Expand Down Expand Up @@ -1066,7 +1065,6 @@
" sampling_function,\n",
" LSST,\n",
" batch_size=batch_size,\n",
" stamp_size=stamp_size,\n",
" njobs=1,\n",
" add_noise=\"all\",\n",
" seed=seed, # use same seed here\n",
Expand Down
1 change: 0 additions & 1 deletion notebooks/01-advanced-deblending.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
" sampling_function,\n",
" LSST,\n",
" batch_size=batch_size,\n",
" stamp_size=24.0,\n",
" njobs=1,\n",
" add_noise=\"all\",\n",
" seed=SEED,\n",
Expand Down
9 changes: 4 additions & 5 deletions notebooks/01-advanced-generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,16 @@
"\n",
" def __init__(\n",
" self,\n",
" stamp_size: float = 24.0,\n",
" max_number: int = 2,\n",
" min_number: int = 1,\n",
" stamp_size: float = 24.0,\n",
" max_shift: Optional[float] = None,\n",
" seed: int = DEFAULT_SEED,\n",
" max_mag: float = 25.3,\n",
" min_mag: float = -np.inf,\n",
" mag_name: str = \"i_ab\",\n",
" ):\n",
" super().__init__(max_number=max_number, min_number=min_number, seed=seed)\n",
" super().__init__(stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed)\n",
" self.stamp_size = stamp_size\n",
" self.max_shift = max_shift if max_shift is not None else self.stamp_size / 10.0\n",
" self.min_mag, self.max_mag = min_mag, max_mag\n",
Expand Down Expand Up @@ -553,12 +553,11 @@
"from btk.sampling_functions import DefaultSampling\n",
"from btk.draw_blends import CosmosGenerator\n",
"\n",
"sampling_func = DefaultSampling(3, 1,\n",
" stamp_size=16,\n",
"sampling_func = DefaultSampling(16, 3, 1,\n",
" mag_name=\"MAG\" # use the name available in the COSMOS catalog.\n",
" )\n",
"\n",
"generator = CosmosGenerator(cosmos_cat, sampling_func, lsst, batch_size=9, stamp_size=16)"
"generator = CosmosGenerator(cosmos_cat, sampling_func, lsst, batch_size=9)"
]
},
{
Expand Down
1 change: 0 additions & 1 deletion notebooks/01-advanced-metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
" sampling_function,\n",
" survey,\n",
" batch_size=batch_size,\n",
" stamp_size=stamp_size,\n",
" njobs=1,\n",
" add_noise=\"all\",\n",
" seed=0,\n",
Expand Down
1 change: 0 additions & 1 deletion notebooks/02-advanced-plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@
" sampling_function,\n",
" survey,\n",
" batch_size=batch_size,\n",
" stamp_size=stamp_size,\n",
" njobs=1,\n",
" add_noise=\"background\",\n",
" seed=seed, # use same seed here\n",
Expand Down
1 change: 0 additions & 1 deletion tests/test_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_cosmos_generator(data_dir):
sampling_function,
survey,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=SEED,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_deblenders.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def test_sep(data_dir):
sampling_function,
survey,
batch_size=batch_size,
stamp_size=24.0,
njobs=1,
add_noise="all",
seed=SEED,
Expand Down Expand Up @@ -84,7 +83,6 @@ def test_scarlet(data_dir):
sampling_function,
LSST,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=seed, # use same seed here
Expand Down
1 change: 0 additions & 1 deletion tests/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_measure(data_dir):
sampling_function,
survey,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=SEED,
Expand Down
Loading

0 comments on commit 04ac79b

Please sign in to comment.