From 6888b52c27731f5d0c14018f61930ef77424ab7d Mon Sep 17 00:00:00 2001 From: Sandro Campos Date: Tue, 29 Oct 2024 13:06:53 -0400 Subject: [PATCH] Patch sky map when catalog has empty partitions (#474) * Handle empty partitions in inner skymap routine * Improve test readability --- src/lsdb/catalog/dataset/healpix_dataset.py | 3 +- src/lsdb/core/plotting/skymap.py | 8 +++-- tests/lsdb/catalog/test_catalog.py | 39 +++++++++++++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index 28d56753..57e9a87c 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -336,13 +336,14 @@ def skymap_data( have the default_value as its result, as well as any pixels for which the aggregate function returns None. """ + results = {} partitions = self.to_delayed() if order is None: results = { pixel: delayed(func)(partitions[index], pixel, **kwargs) for pixel, index in self._ddf_pixel_map.items() } - else: + elif len(self.hc_structure.pixel_tree) > 0: if order < self.hc_structure.pixel_tree.get_max_depth(): raise ValueError( f"order must be greater than or equal to max order in catalog " diff --git a/src/lsdb/core/plotting/skymap.py b/src/lsdb/core/plotting/skymap.py index ef32d64b..fba6e5a7 100644 --- a/src/lsdb/core/plotting/skymap.py +++ b/src/lsdb/core/plotting/skymap.py @@ -19,6 +19,12 @@ def perform_inner_skymap( **kwargs, ) -> np.ndarray: """Splits a partition into pixels at a target order and performs a given function on the new pixels""" + delta_order = target_order - pixel.order + img = np.full(1 << 2 * delta_order, fill_value=default_value) + + if len(partition) == 0: + return img + spatial_index = partition.index.to_numpy() order_pixels = spatial_index_to_healpix(spatial_index, target_order=target_order) @@ -28,8 +34,6 @@ def apply_func(df): return func(df, HealpixPixel(target_order, p), **kwargs) gb = partition.groupby(order_pixels, sort=False).apply(apply_func) - delta_order = target_order - pixel.order - img = np.full(1 << 2 * delta_order, fill_value=default_value) min_pixel_value = pixel.pixel << 2 * delta_order img[gb.index.to_numpy() - min_pixel_value] = gb.to_numpy(na_value=default_value) return img diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index ff6576d3..1abd4aea 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -458,6 +458,45 @@ def func(df, healpix): assert np.array_equal(expected_arr, arr) +def test_skymap_histogram_order_empty(small_sky_order1_catalog): + order = 3 + + def func(df, healpix): + return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + + catalog = small_sky_order1_catalog.cone_search(0, 0, 1) + _, non_empty_partitions = catalog._get_non_empty_partitions() + assert len(non_empty_partitions) == 0 + + img = catalog.skymap_histogram(func, order) + expected_img = np.zeros(hp.order2npix(order)) + assert (img == expected_img).all() + + +def test_skymap_histogram_order_some_partitions_empty(small_sky_order1_catalog): + order = 3 + + def func(df, healpix): + return len(df) / hp.nside2pixarea(hp.order2nside(healpix.order), degrees=True) + + catalog = small_sky_order1_catalog.query("ra > 350 and dec < -50") + _, non_empty_partitions = catalog._get_non_empty_partitions() + assert 0 < len(non_empty_partitions) < catalog._ddf.npartitions + + img = catalog.skymap_histogram(func, order) + + pixel_map = catalog.skymap_data(func, order) + pixel_map = {pixel: value.compute() for pixel, value in pixel_map.items()} + expected_img = np.zeros(hp.order2npix(order)) + for pixel, value in pixel_map.items(): + dorder = order - pixel.order + start = pixel.pixel * (4**dorder) + end = (pixel.pixel + 1) * (4**dorder) + img_order_pixels = np.arange(start, end) + expected_img[img_order_pixels] = value + assert (img == expected_img).all() + + # pylint: disable=no-member def test_skymap_plot(small_sky_order1_catalog, mocker): mocker.patch("lsdb.catalog.dataset.healpix_dataset.plot_healpix_map")