Skip to content

Commit

Permalink
Split unit tests (#494)
Browse files Browse the repository at this point in the history
* BUG in blednedness calculation after we subtracted the 1.0

* split up and add some more tests
  • Loading branch information
ismael-mendoza authored Apr 26, 2024
1 parent 7ae706d commit 8a0a780
Show file tree
Hide file tree
Showing 7 changed files with 300 additions and 168 deletions.
4 changes: 2 additions & 2 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_blendedness(iso_image: np.ndarray) -> np.ndarray:
Args:
iso_image: Array of shape = (..., N, H, W) corresponding to images of isolated
galaxiesi you are calculating blendedness for.
galaxies you are calculating blendedness for.
Returns:
Array of size (..., N) corresponding to blendedness values for each individual galaxy.
Expand All @@ -70,7 +70,7 @@ def get_blendedness(iso_image: np.ndarray) -> np.ndarray:
num = np.sum(iso_image * iso_image, axis=(-1, -2))
blend = np.sum(iso_image, axis=-3)[..., None, :, :]
denom = np.sum(blend * iso_image, axis=(-1, -2))
return 1 - np.divide(num, denom, out=np.zeros_like(num), where=(num != 0))
return 1 - np.divide(num, denom, out=np.ones_like(num), where=(num != 0))


def get_snr(iso_image: np.ndarray, sky_level: float) -> np.ndarray:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_cosmos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import btk
from btk.survey import Survey

SEED = 0


def test_cosmos_generator(data_dir):
"""Test the pipeline as a whole for a single deblender."""
cosmos_catalog_paths = [
data_dir / "cosmos" / "real_galaxy_catalog_23.5_example.fits",
data_dir / "cosmos" / "real_galaxy_catalog_23.5_example_fits.fits",
]
cosmos_catalog_files = [p.as_posix() for p in cosmos_catalog_paths]
catalog = btk.catalog.CosmosCatalog.from_file(cosmos_catalog_files)

_ = catalog.get_raw_catalog()

survey: Survey = btk.survey.get_surveys("LSST")
fltr = survey.get_filter("r")
assert hasattr(fltr, "psf")

stamp_size = 24.0
max_shift = 1.0
max_n_sources = 2
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=max_n_sources,
min_number=1,
stamp_size=stamp_size,
max_shift=max_shift,
min_mag=20,
max_mag=21,
seed=SEED,
mag_name="MAG",
)

batch_size = 10

draw_generator = btk.draw_blends.CosmosGenerator(
catalog,
sampling_function,
survey,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=SEED,
gal_type="real",
)

# generate batch 100 blend catalogs and images.
blend_batch = next(draw_generator)
assert len(blend_batch.catalog_list) == batch_size
assert blend_batch.blend_images.shape == (batch_size, 6, stamp_size / 0.2, stamp_size / 0.2)
iso_shape = (batch_size, max_n_sources, 6, stamp_size / 0.2, stamp_size / 0.2)
assert blend_batch.isolated_images.shape == iso_shape
97 changes: 97 additions & 0 deletions tests/test_deblenders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np

import btk
from btk.survey import Survey

SEED = 0


def test_sep(data_dir):
"""Check we always detect single bright objects."""

catalog_file = data_dir / "input_catalog.fits"
catalog = btk.catalog.CatsimCatalog.from_file(catalog_file)
survey: Survey = btk.survey.get_surveys("LSST")

# single bright galaxy=
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=1,
min_number=1,
stamp_size=24.0,
max_shift=1.0,
min_mag=0,
max_mag=21,
seed=SEED,
)

assert np.sum((catalog.table["i_ab"] > 0) & (catalog.table["i_ab"] < 21)) > 100

batch_size = 100

draw_generator = btk.draw_blends.CatsimGenerator(
catalog,
sampling_function,
survey,
batch_size=batch_size,
stamp_size=24.0,
njobs=1,
add_noise="all",
seed=SEED,
)

blend_batch = next(draw_generator)
deblender = btk.deblend.SepSingleBand(max_n_sources=1, thresh=3, use_band=2)
deblend_batch = deblender(blend_batch, njobs=1)

matcher = btk.match.PixelHungarianMatcher(pixel_max_sep=5.0)

true_catalog_list = blend_batch.catalog_list
pred_catalog_list = deblend_batch.catalog_list
matching = matcher(true_catalog_list, pred_catalog_list) # matching object
tp, t, p = matching.tp, matching.t, matching.p

recall = btk.metrics.detection.Recall(batch_size)
precision = btk.metrics.detection.Precision(batch_size)

assert recall(tp, t, p) > 0.95
assert precision(tp, t, p) > 0.95


def test_scarlet(data_dir):
"""Check scarlet deblender implementation runs without too many failures."""

max_n_sources = 3
stamp_size = 24.0
seed = 0
max_shift = 2.0 # shift is only 2 arcsecs -> 10 pixels, so blends are likely.

catalog = btk.catalog.CatsimCatalog.from_file(data_dir / "input_catalog.fits")
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=max_n_sources,
min_number=max_n_sources, # always 3 sources in every blend.
stamp_size=stamp_size,
max_shift=max_shift,
min_mag=24,
max_mag=25,
seed=seed,
)
LSST = btk.survey.get_surveys("LSST")

batch_size = 10

draw_generator = btk.draw_blends.CatsimGenerator(
catalog,
sampling_function,
LSST,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=seed, # use same seed here
)

blend_batch = next(draw_generator)
deblender = btk.deblend.Scarlet(max_n_sources)
deblend_batch = deblender(blend_batch, reference_catalogs=blend_batch.catalog_list)
n_failures = np.sum([len(cat) == 0 for cat in deblend_batch.catalog_list], axis=0)
assert n_failures <= 3
33 changes: 33 additions & 0 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from astropy.table import Table

from btk.match import PixelHungarianMatcher


def test_matching():
x1 = [12.0, 31.0]
y1 = [10.0, 30.0]
x2 = [34.0, 12.1, 20.1]
y2 = [33.0, 10.1, 22.0]

t1 = Table()
t1["x_peak"] = x1
t1["y_peak"] = y1

t2 = Table()
t2["x_peak"] = x2
t2["y_peak"] = y2

catalog_list1 = [t1]
catalog_list2 = [t2]

matcher1 = PixelHungarianMatcher(pixel_max_sep=1)

match = matcher1(catalog_list1, catalog_list2)

assert match.n_true == 2
assert match.n_pred == 3
assert match.tp == 1
assert match.fp == 2

assert match.true_matches == [[0]]
assert match.pred_matches == [[1]]
87 changes: 87 additions & 0 deletions tests/test_measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Test measure functions run on simple outputs from generator and deblenders."""

import numpy as np
from galcheat.utilities import mean_sky_level

import btk
from btk.measure import get_aperture_fluxes, get_blendedness, get_ksb_ellipticity, get_snr
from btk.survey import Survey

SEED = 0


def test_measure(data_dir):
catalog_file = data_dir / "input_catalog.fits"
catalog = btk.catalog.CatsimCatalog.from_file(catalog_file)

_ = catalog.get_raw_catalog()

survey: Survey = btk.survey.get_surveys("LSST")
fltr = survey.get_filter("r")
assert hasattr(fltr, "psf")

stamp_size = 24.0
max_shift = 2.0
max_n_sources = 4
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=max_n_sources,
min_number=1,
stamp_size=stamp_size,
max_shift=max_shift,
min_mag=20,
max_mag=21,
seed=SEED,
)

batch_size = 10

draw_generator = btk.draw_blends.CatsimGenerator(
catalog,
sampling_function,
survey,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=SEED,
)

batch = next(draw_generator)
sky_level = mean_sky_level(survey, survey.get_filter("r")).to_value("electron")

# combine all centroids
xs_peak = np.zeros((batch_size, max_n_sources))
ys_peak = np.zeros((batch_size, max_n_sources))
for ii, t in enumerate(batch.catalog_list):
n_sources = len(t["x_peak"])
xs_peak[ii, :n_sources] = t["x_peak"].value
ys_peak[ii, :n_sources] = t["y_peak"].value

# aperture photometry
fluxes, fluxerr = get_aperture_fluxes(batch.blend_images[:, 2], xs_peak, ys_peak, 5, sky_level)
assert fluxes.shape == (batch_size, max_n_sources)
assert fluxerr.shape == (batch_size, max_n_sources)

# blendedness
blendedness = get_blendedness(batch.isolated_images[:, :, 2])
assert blendedness.shape == (batch_size, max_n_sources)
assert np.all(np.less_equal(blendedness, 1)) and np.all(np.greater_equal(blendedness, 0.0))

# snr
snr = get_snr(batch.isolated_images[:, :, 2], sky_level)
snr.shape == (batch_size, max_n_sources)
assert np.all(np.greater_equal(snr, 0))

# ellipticity
ellips = get_ksb_ellipticity(batch.isolated_images[:, :, 2], batch.psf[2], 0.2)
assert ellips.shape == (batch_size, max_n_sources, 2)

# zeroes if no galaxies
for ii in range(batch_size):
n_sources = len(batch.catalog_list[ii])
for jj in range(max_n_sources):
if jj >= n_sources:
print(blendedness)
assert snr[ii, jj] == 0
assert np.all(np.isnan(ellips[ii, jj]))
assert blendedness[ii, jj] == 0
Loading

0 comments on commit 8a0a780

Please sign in to comment.