diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py
index 67c1b807..c8c5a2fb 100644
--- a/benchmarks/benchmarks.py
+++ b/benchmarks/benchmarks.py
@@ -22,15 +22,15 @@


 def load_small_sky():
-    return lsdb.read_hats(TEST_DIR / DATA_DIR_NAME / SMALL_SKY_DIR_NAME, catalog_type=lsdb.Catalog)
+    return lsdb.read_hats(TEST_DIR / DATA_DIR_NAME / SMALL_SKY_DIR_NAME)


 def load_small_sky_order1():
-    return lsdb.read_hats(TEST_DIR / DATA_DIR_NAME / SMALL_SKY_ORDER1, catalog_type=lsdb.Catalog)
+    return lsdb.read_hats(TEST_DIR / DATA_DIR_NAME / SMALL_SKY_ORDER1)


 def load_small_sky_xmatch():
-    return lsdb.read_hats(TEST_DIR / DATA_DIR_NAME / SMALL_SKY_XMATCH_NAME, catalog_type=lsdb.Catalog)
+    return lsdb.read_hats(TEST_DIR / DATA_DIR_NAME / SMALL_SKY_XMATCH_NAME)


 def time_kdtree_crossmatch():
diff --git a/src/lsdb/loaders/hats/abstract_catalog_loader.py b/src/lsdb/loaders/hats/abstract_catalog_loader.py
deleted file mode 100644
index 119f8cbe..00000000
--- a/src/lsdb/loaders/hats/abstract_catalog_loader.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from __future__ import annotations
-
-from abc import abstractmethod
-from pathlib import Path
-from typing import Generic, List, Tuple, Type
-
-import hats as hc
-import nested_dask as nd
-import nested_pandas as npd
-import numpy as np
-import pyarrow as pa
-from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
-from hats.io.file_io import file_io
-from hats.pixel_math import HealpixPixel
-from hats.pixel_math.healpix_pixel_function import get_pixel_argsort
-from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN
-from upath import UPath
-
-from lsdb.catalog.catalog import DaskDFPixelMap
-from lsdb.dask.divisions import get_pixels_divisions
-from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig
-from lsdb.types import CatalogTypeVar, HCCatalogTypeVar
-
-
-class AbstractCatalogLoader(Generic[CatalogTypeVar]):
-    """Loads a HATS Dataset with the type specified by the type variable"""
-
-    def __init__(self, path: str | Path | UPath, config: HatsLoadingConfig) -> None:
-        """Initializes a HatsCatalogLoader
-
-        Args:
-            path: path to the root of the HATS catalog
-            config: options to configure how the catalog is loaded
-        """
-        self.path = path
-        self.base_catalog_dir = hc.io.file_io.get_upath(self.path)
-        self.config = config
-
-    @abstractmethod
-    def load_catalog(self) -> CatalogTypeVar | None:
-        """Load a dataset from the configuration specified when the loader was created
-
-        Returns:
-            Dataset object of the class's type with data from the source given at loader initialization
-        """
-        pass
-
-    def _load_hats_catalog(self, catalog_type: Type[HCCatalogTypeVar]) -> HCCatalogTypeVar:
-        """Load `hats` library catalog object with catalog metadata and partition data"""
-        hc_catalog = catalog_type.read_hats(self.path)
-        if hc_catalog.schema is None:
-            raise ValueError(
-                "The catalog schema could not be loaded from metadata."
- " Ensure your catalog has _common_metadata or _metadata files" - ) - return hc_catalog - - def _load_dask_df_and_map(self, catalog: HCHealpixDataset) -> Tuple[nd.NestedFrame, DaskDFPixelMap]: - """Load Dask DF from parquet files and make dict of HEALPix pixel to partition index""" - pixels = catalog.get_healpix_pixels() - ordered_pixels = np.array(pixels)[get_pixel_argsort(pixels)] - divisions = get_pixels_divisions(ordered_pixels) - ddf = self._load_df_from_pixels(catalog, ordered_pixels, divisions) - pixel_to_index_map = {pixel: index for index, pixel in enumerate(ordered_pixels)} - return ddf, pixel_to_index_map - - def _load_df_from_pixels( - self, catalog: HCHealpixDataset, ordered_pixels: List[HealpixPixel], divisions: Tuple[int, ...] | None - ) -> nd.NestedFrame: - dask_meta_schema = self._create_dask_meta_schema(catalog.schema) - if len(ordered_pixels) > 0: - return nd.NestedFrame.from_map( - read_pixel, - ordered_pixels, - catalog=catalog, - query_url_params=self.config.make_query_url_params(), - columns=self.config.columns, - divisions=divisions, - meta=dask_meta_schema, - schema=catalog.schema, - **self._get_kwargs(), - ) - return nd.NestedFrame.from_pandas(dask_meta_schema, npartitions=1) - - def _create_dask_meta_schema(self, schema: pa.Schema) -> npd.NestedFrame: - """Creates the Dask meta DataFrame from the HATS catalog schema.""" - dask_meta_schema = schema.empty_table().to_pandas(types_mapper=self.config.get_dtype_mapper()) - if ( - dask_meta_schema.index.name != SPATIAL_INDEX_COLUMN - and SPATIAL_INDEX_COLUMN in dask_meta_schema.columns - ): - dask_meta_schema = dask_meta_schema.set_index(SPATIAL_INDEX_COLUMN) - if self.config.columns is not None and SPATIAL_INDEX_COLUMN in self.config.columns: - self.config.columns.remove(SPATIAL_INDEX_COLUMN) - if self.config.columns is not None: - dask_meta_schema = dask_meta_schema[self.config.columns] - return npd.NestedFrame(dask_meta_schema) - - def _get_kwargs(self) -> dict: - """Constructs additional arguments for the `read_parquet` call""" - kwargs = dict(self.config.kwargs) - if self.config.dtype_backend is not None: - kwargs["dtype_backend"] = self.config.dtype_backend - return kwargs - - -def read_pixel( - pixel: HealpixPixel, - catalog: HCHealpixDataset, - *, - query_url_params: dict | None = None, - columns=None, - schema=None, - **kwargs, -): - """Utility method to read a single pixel's parquet file from disk.""" - if ( - columns is not None - and schema is not None - and SPATIAL_INDEX_COLUMN in schema.names - and SPATIAL_INDEX_COLUMN not in columns - ): - columns = columns + [SPATIAL_INDEX_COLUMN] - dataframe = file_io.read_parquet_file_to_pandas( - hc.io.pixel_catalog_file(catalog.catalog_base_dir, pixel, query_url_params), - columns=columns, - schema=schema, - **kwargs, - ) - - if dataframe.index.name != SPATIAL_INDEX_COLUMN and SPATIAL_INDEX_COLUMN in dataframe.columns: - dataframe = dataframe.set_index(SPATIAL_INDEX_COLUMN) - - return dataframe diff --git a/src/lsdb/loaders/hats/association_catalog_loader.py b/src/lsdb/loaders/hats/association_catalog_loader.py deleted file mode 100644 index 0895c50d..00000000 --- a/src/lsdb/loaders/hats/association_catalog_loader.py +++ /dev/null @@ -1,27 +0,0 @@ -import hats as hc -import nested_dask as nd - -from lsdb.catalog.association_catalog import AssociationCatalog -from lsdb.loaders.hats.abstract_catalog_loader import AbstractCatalogLoader - - -class AssociationCatalogLoader(AbstractCatalogLoader[AssociationCatalog]): - """Loads an HATS AssociationCatalog""" - - def 
-        """Load a catalog from the configuration specified when the loader was created
-
-        Returns:
-            Catalog object with data from the source given at loader initialization
-        """
-        hc_catalog = self._load_hats_catalog(hc.catalog.AssociationCatalog)
-        if hc_catalog.catalog_info.contains_leaf_files:
-            dask_df, dask_df_pixel_map = self._load_dask_df_and_map(hc_catalog)
-        else:
-            dask_df, dask_df_pixel_map = self._load_empty_dask_df_and_map(hc_catalog)
-        return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog)
-
-    def _load_empty_dask_df_and_map(self, hc_catalog):
-        dask_meta_schema = self._create_dask_meta_schema(hc_catalog.schema)
-        ddf = nd.NestedFrame.from_pandas(dask_meta_schema, npartitions=1)
-        return ddf, {}
diff --git a/src/lsdb/loaders/hats/hats_catalog_loader.py b/src/lsdb/loaders/hats/hats_catalog_loader.py
deleted file mode 100644
index b684dc90..00000000
--- a/src/lsdb/loaders/hats/hats_catalog_loader.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import annotations
-
-import hats as hc
-
-from lsdb.catalog.catalog import Catalog, MarginCatalog
-from lsdb.loaders.hats.abstract_catalog_loader import AbstractCatalogLoader
-from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig
-from lsdb.loaders.hats.margin_catalog_loader import MarginCatalogLoader
-
-
-class HatsCatalogLoader(AbstractCatalogLoader[Catalog]):
-    """Loads a HATS formatted Catalog"""
-
-    def load_catalog(self) -> Catalog:
-        """Load a catalog from the configuration specified when the loader was created
-
-        Returns:
-            Catalog object with data from the source given at loader initialization
-        """
-        hc_catalog = self._load_hats_catalog(hc.catalog.Catalog)
-        filtered_hc_catalog = self._filter_hats_catalog(hc_catalog)
-        dask_df, dask_df_pixel_map = self._load_dask_df_and_map(filtered_hc_catalog)
-        catalog = Catalog(dask_df, dask_df_pixel_map, filtered_hc_catalog)
-        if self.config.search_filter is not None:
-            catalog = catalog.search(self.config.search_filter)
-        catalog.margin = self._load_margin_catalog()
-        return catalog
-
-    def _filter_hats_catalog(self, hc_catalog: hc.catalog.Catalog) -> hc.catalog.Catalog:
-        """Filter the catalog pixels according to the spatial filter provided at loading time.
-        Object and source catalogs are not allowed to be filtered to an empty catalog. If the
-        resulting catalog is empty an error is issued indicating that the catalog does not have
-        coverage for the desired region in the sky."""
-        if self.config.search_filter is None:
-            return hc_catalog
-        filtered_catalog = self.config.search_filter.filter_hc_catalog(hc_catalog)
-        if len(filtered_catalog.get_healpix_pixels()) == 0:
-            raise ValueError("The selected sky region has no coverage")
-        return hc.catalog.Catalog(
-            filtered_catalog.catalog_info,
-            filtered_catalog.pixel_tree,
-            catalog_path=hc_catalog.catalog_path,
-            moc=filtered_catalog.moc,
-            schema=filtered_catalog.schema,
-        )
-
-    def _load_margin_catalog(self) -> MarginCatalog | None:
-        """Load the margin catalog. It can be provided using a margin catalog
-        instance or a path to the catalog on disk."""
-        margin_catalog = None
-        if isinstance(self.config.margin_cache, MarginCatalog):
-            margin_catalog = self.config.margin_cache
-            if self.config.search_filter is not None:
-                # pylint: disable=protected-access
-                margin_catalog = margin_catalog.search(self.config.search_filter)
-        elif self.config.margin_cache is not None:
-            margin_catalog = MarginCatalogLoader(
-                str(self.config.margin_cache),
-                HatsLoadingConfig(
-                    search_filter=self.config.search_filter,
-                    columns=self.config.columns,
-                    margin_cache=None,
-                    dtype_backend=self.config.dtype_backend,
-                    **self.config.kwargs,
-                ),
-            ).load_catalog()
-        return margin_catalog
diff --git a/src/lsdb/loaders/hats/hats_loader_factory.py b/src/lsdb/loaders/hats/hats_loader_factory.py
deleted file mode 100644
index 39a34c3a..00000000
--- a/src/lsdb/loaders/hats/hats_loader_factory.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import Dict, Type
-
-from upath import UPath
-
-from lsdb.catalog.association_catalog import AssociationCatalog
-from lsdb.catalog.catalog import Catalog
-from lsdb.catalog.dataset.dataset import Dataset
-from lsdb.catalog.margin_catalog import MarginCatalog
-from lsdb.loaders.hats.abstract_catalog_loader import AbstractCatalogLoader, CatalogTypeVar
-from lsdb.loaders.hats.association_catalog_loader import AssociationCatalogLoader
-from lsdb.loaders.hats.hats_catalog_loader import HatsCatalogLoader
-from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig
-from lsdb.loaders.hats.margin_catalog_loader import MarginCatalogLoader
-
-loader_class_for_catalog_type: Dict[Type[Dataset], Type[AbstractCatalogLoader]] = {
-    Catalog: HatsCatalogLoader,
-    AssociationCatalog: AssociationCatalogLoader,
-    MarginCatalog: MarginCatalogLoader,
-}
-
-
-def get_loader_for_type(
-    catalog_type_to_use: Type[CatalogTypeVar], path: str | Path | UPath, config: HatsLoadingConfig
-) -> AbstractCatalogLoader:
-    """Constructs a CatalogLoader that loads a Dataset of the specified type
-
-    Args:
-        catalog_type_to_use (Type[Dataset]): the type of catalog to be loaded. Uses the actual type
-            as the input, not a string or enum value
-        path (UPath): the path to load the catalog from
-        config (HatsLoadingConfig): Additional configuration for loading the catalog
-
-    Returns:
-        An initialized CatalogLoader object with the path and config specified
-    """
-    if catalog_type_to_use not in loader_class_for_catalog_type:
-        raise ValueError(f"Cannot load catalog type: {str(catalog_type_to_use)}")
-    loader_class = loader_class_for_catalog_type[catalog_type_to_use]
-    return loader_class(path, config)
diff --git a/src/lsdb/loaders/hats/hats_loading_config.py b/src/lsdb/loaders/hats/hats_loading_config.py
index 7121e39a..5962209e 100644
--- a/src/lsdb/loaders/hats/hats_loading_config.py
+++ b/src/lsdb/loaders/hats/hats_loading_config.py
@@ -65,3 +65,10 @@ def make_query_url_params(self) -> dict:
                 url_params["filters"].append(f"{filtr[0]}{filtr[1]}{filtr[2]}")
         return url_params
+
+    def get_read_kwargs(self):
+        """Merges the existing kwargs with `dtype_backend`, if specified."""
+        kwargs = dict(self.kwargs)
+        if self.dtype_backend is not None:
+            kwargs["dtype_backend"] = self.dtype_backend
+        return kwargs
diff --git a/src/lsdb/loaders/hats/margin_catalog_loader.py b/src/lsdb/loaders/hats/margin_catalog_loader.py
deleted file mode 100644
index 5678ee66..00000000
--- a/src/lsdb/loaders/hats/margin_catalog_loader.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from __future__ import annotations
-
-import hats as hc
-
-from lsdb.catalog.margin_catalog import MarginCatalog
-from lsdb.loaders.hats.abstract_catalog_loader import AbstractCatalogLoader
-
-
-class MarginCatalogLoader(AbstractCatalogLoader[MarginCatalog]):
-    """Loads an HATS MarginCatalog"""
-
-    def load_catalog(self) -> MarginCatalog | None:
-        """Load a catalog from the configuration specified when the loader was created
-
-        Returns:
-            Catalog object with data from the source given at loader initialization
-        """
-        hc_catalog = self._load_hats_catalog(hc.catalog.MarginCatalog)
-        filtered_hc_catalog = self._filter_hats_catalog(hc_catalog)
-        dask_df, dask_df_pixel_map = self._load_dask_df_and_map(filtered_hc_catalog)
-        margin = MarginCatalog(dask_df, dask_df_pixel_map, filtered_hc_catalog)
-        if self.config.search_filter is not None:
-            margin = margin.search(self.config.search_filter)
-        return margin
-
-    def _filter_hats_catalog(self, hc_catalog: hc.catalog.MarginCatalog) -> hc.catalog.MarginCatalog:
-        """Filter the catalog pixels according to the spatial filter provided at loading time.
-        Margin catalogs, unlike object and source catalogs, are allowed to be filtered to an
-        empty catalog. In that case, the margin catalog is considered None."""
-        if self.config.search_filter is None:
-            return hc_catalog
-        filtered_catalog = self.config.search_filter.filter_hc_catalog(hc_catalog)
-        return hc.catalog.MarginCatalog(
-            filtered_catalog.catalog_info,
-            filtered_catalog.pixel_tree,
-            catalog_path=hc_catalog.catalog_path,
-            schema=filtered_catalog.schema,
-        )
diff --git a/src/lsdb/loaders/hats/read_hats.py b/src/lsdb/loaders/hats/read_hats.py
index d189b620..07c38b58 100644
--- a/src/lsdb/loaders/hats/read_hats.py
+++ b/src/lsdb/loaders/hats/read_hats.py
@@ -1,34 +1,31 @@
 from __future__ import annotations

-import dataclasses
 from pathlib import Path
-from typing import Dict, List, Type
+from typing import List, Tuple

 import hats as hc
-from hats.catalog import CatalogType, TableProperties
+import nested_dask as nd
+import nested_pandas as npd
+import numpy as np
+import pyarrow as pa
+from hats.catalog import CatalogType
+from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
+from hats.io.file_io import file_io
+from hats.pixel_math import HealpixPixel
+from hats.pixel_math.healpix_pixel_function import get_pixel_argsort
+from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN
 from upath import UPath

 from lsdb.catalog.association_catalog import AssociationCatalog
-from lsdb.catalog.catalog import Catalog
-from lsdb.catalog.dataset.dataset import Dataset
-from lsdb.catalog.margin_catalog import MarginCatalog
+from lsdb.catalog.catalog import Catalog, DaskDFPixelMap, MarginCatalog
 from lsdb.core.search.abstract_search import AbstractSearch
-from lsdb.loaders.hats.abstract_catalog_loader import CatalogTypeVar
-from lsdb.loaders.hats.hats_loader_factory import get_loader_for_type
+from lsdb.dask.divisions import get_pixels_divisions
 from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig
+from lsdb.types import CatalogTypeVar

-dataset_class_for_catalog_type: Dict[CatalogType, Type[Dataset]] = {
-    CatalogType.OBJECT: Catalog,
-    CatalogType.SOURCE: Catalog,
-    CatalogType.ASSOCIATION: AssociationCatalog,
-    CatalogType.MARGIN: MarginCatalog,
-}
-
-# pylint: disable=unused-argument
 def read_hats(
     path: str | Path | UPath,
-    catalog_type: Type[CatalogTypeVar] | None = None,
     search_filter: AbstractSearch | None = None,
     columns: List[str] | None = None,
     margin_cache: MarginCatalog | str | Path | UPath | None = None,
@@ -45,18 +42,12 @@

         lsdb.read_hats(
             path="./my_catalog_dir",
-            catalog_type=lsdb.Catalog,
             columns=["ra","dec"],
             search_filter=lsdb.core.search.ConeSearch(ra, dec, radius_arcsec),
         )

     Args:
         path (UPath | Path): The path that locates the root of the HATS catalog
-        catalog_type (Type[Dataset]): Default `None`. By default, the type of the catalog is loaded
-            from the catalog info and the corresponding object type is returned. Python's type hints
-            cannot allow a return type specified by a loaded value, so to use the correct return
-            type for type checking, the type of the catalog can be specified here. Use by specifying
-            the lsdb class for that catalog.
         search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
         columns (List[str]): Default `None`. The set of columns to filter the catalog on.
         margin_cache (MarginCatalog or path-like): The margin cache for the main catalog,
@@ -69,23 +60,170 @@
         Catalog object loaded from the given parameters
     """
     # Creates a config object to store loading parameters from all keyword arguments.
-    kwd_args = locals().copy()
-    config_args = {field.name: kwd_args[field.name] for field in dataclasses.fields(HatsLoadingConfig)}
-    config = HatsLoadingConfig(**config_args)
+    config = HatsLoadingConfig(
+        search_filter=search_filter,
+        columns=columns,
+        margin_cache=margin_cache,
+        dtype_backend=dtype_backend,
+        kwargs=kwargs,
+    )
+
+    hc_catalog = hc.read_hats(path)
+    if hc_catalog.schema is None:
+        raise ValueError(
+            "The catalog schema could not be loaded from metadata."
+            " Ensure your catalog has _common_metadata or _metadata files"
+        )
+
+    catalog_type = hc_catalog.catalog_info.catalog_type
+
+    if catalog_type in (CatalogType.OBJECT, CatalogType.SOURCE):
+        return _load_object_catalog(hc_catalog, config)
+    if catalog_type == CatalogType.MARGIN:
+        return _load_margin_catalog(hc_catalog, config)
+    if catalog_type == CatalogType.ASSOCIATION:
+        return _load_association_catalog(hc_catalog, config)
+
+    raise NotImplementedError(f"Cannot load catalog of type {catalog_type}")
+
+
+def _load_association_catalog(hc_catalog, config):
+    """Load an association catalog from the given HATS catalog metadata and configuration
+
+    Returns:
+        AssociationCatalog object with data from the source given at load time
+    """
+    if hc_catalog.catalog_info.contains_leaf_files:
+        dask_df, dask_df_pixel_map = _load_dask_df_and_map(hc_catalog, config)
+    else:
+        dask_meta_schema = _create_dask_meta_schema(hc_catalog.schema, config)
+        dask_df = nd.NestedFrame.from_pandas(dask_meta_schema, npartitions=1)
+        dask_df_pixel_map = {}
+    return AssociationCatalog(dask_df, dask_df_pixel_map, hc_catalog)

-    catalog_type_to_use = _get_dataset_class_from_catalog_info(path)

-    if catalog_type is not None:
-        catalog_type_to_use = catalog_type

+def _load_margin_catalog(hc_catalog, config):
+    """Load a margin catalog from the given HATS catalog metadata and configuration

-    loader = get_loader_for_type(catalog_type_to_use, path, config)
-    return loader.load_catalog()

+    Returns:
+        MarginCatalog object with data from the source given at load time
+    """
+    if config.search_filter:
+        filtered_catalog = config.search_filter.filter_hc_catalog(hc_catalog)
+        hc_catalog = hc.catalog.MarginCatalog(
+            filtered_catalog.catalog_info,
+            filtered_catalog.pixel_tree,
+            catalog_path=hc_catalog.catalog_path,
+            schema=filtered_catalog.schema,
+            moc=filtered_catalog.moc,
+        )
+    dask_df, dask_df_pixel_map = _load_dask_df_and_map(hc_catalog, config)
+    margin = MarginCatalog(dask_df, dask_df_pixel_map, hc_catalog)
+    if config.search_filter is not None:
+        margin = margin.search(config.search_filter)
+    return margin
+
+
+def _load_object_catalog(hc_catalog, config):
+    """Load an object or source catalog from the given HATS catalog metadata and configuration
+
+    Returns:
+        Catalog object with data from the source given at load time
+    """
+    if config.search_filter:
+        filtered_catalog = config.search_filter.filter_hc_catalog(hc_catalog)
+        if len(filtered_catalog.get_healpix_pixels()) == 0:
+            raise ValueError("The selected sky region has no coverage")
+        hc_catalog = hc.catalog.Catalog(
+            filtered_catalog.catalog_info,
+            filtered_catalog.pixel_tree,
+            catalog_path=hc_catalog.catalog_path,
+            moc=filtered_catalog.moc,
+            schema=filtered_catalog.schema,
+        )

-def _get_dataset_class_from_catalog_info(base_catalog_path: str | Path | UPath) -> Type[Dataset]:
-    base_catalog_dir = hc.io.file_io.get_upath(base_catalog_path)
-    catalog_info = TableProperties.read_from_dir(base_catalog_dir)
-    catalog_type = catalog_info.catalog_type
-    if catalog_type not in dataset_class_for_catalog_type:
-        raise NotImplementedError(f"Cannot load catalog of type {catalog_type}")
-    return dataset_class_for_catalog_type[catalog_type]
+    dask_df, dask_df_pixel_map = _load_dask_df_and_map(hc_catalog, config)
+    catalog = Catalog(dask_df, dask_df_pixel_map, hc_catalog)
+    if config.search_filter is not None:
+        catalog = catalog.search(config.search_filter)
+    if isinstance(config.margin_cache, MarginCatalog):
+        catalog.margin = config.margin_cache
+        if config.search_filter is not None:
+            catalog.margin = catalog.margin.search(config.search_filter)
+    elif config.margin_cache is not None:
+        margin_hc_catalog = hc.read_hats(config.margin_cache)
+        catalog.margin = _load_margin_catalog(margin_hc_catalog, config)
+    return catalog
+
+
+def _create_dask_meta_schema(schema: pa.Schema, config) -> npd.NestedFrame:
+    """Creates the Dask meta DataFrame from the HATS catalog schema."""
+    dask_meta_schema = schema.empty_table().to_pandas(types_mapper=config.get_dtype_mapper())
+    if (
+        dask_meta_schema.index.name != SPATIAL_INDEX_COLUMN
+        and SPATIAL_INDEX_COLUMN in dask_meta_schema.columns
+    ):
+        dask_meta_schema = dask_meta_schema.set_index(SPATIAL_INDEX_COLUMN)
+    if config.columns is not None and SPATIAL_INDEX_COLUMN in config.columns:
+        config.columns.remove(SPATIAL_INDEX_COLUMN)
+    if config.columns is not None:
+        dask_meta_schema = dask_meta_schema[config.columns]
+    return npd.NestedFrame(dask_meta_schema)
+
+
+def _load_dask_df_and_map(catalog: HCHealpixDataset, config) -> Tuple[nd.NestedFrame, DaskDFPixelMap]:
+    """Load Dask DF from parquet files and make dict of HEALPix pixel to partition index"""
+    pixels = catalog.get_healpix_pixels()
+    ordered_pixels = np.array(pixels)[get_pixel_argsort(pixels)]
+    divisions = get_pixels_divisions(ordered_pixels)
+    dask_meta_schema = _create_dask_meta_schema(catalog.schema, config)
+    if len(ordered_pixels) > 0:
+        ddf = nd.NestedFrame.from_map(
+            read_pixel,
+            ordered_pixels,
+            catalog=catalog,
+            query_url_params=config.make_query_url_params(),
+            columns=config.columns,
+            divisions=divisions,
+            meta=dask_meta_schema,
+            schema=catalog.schema,
+            **config.get_read_kwargs(),
+        )
+    else:
+        ddf = nd.NestedFrame.from_pandas(dask_meta_schema, npartitions=1)
+    pixel_to_index_map = {pixel: index for index, pixel in enumerate(ordered_pixels)}
+    return ddf, pixel_to_index_map
+
+
+def read_pixel(
+    pixel: HealpixPixel,
+    catalog: HCHealpixDataset,
+    *,
+    query_url_params: dict | None = None,
+    columns=None,
+    schema=None,
+    **kwargs,
+):
+    """Utility method to read a single pixel's parquet file from disk.
+
+    NB: `columns` is necessary as an argument, even if None, so that dask-expr
+    optimizes the execution plan."""
+    if (
+        columns is not None
+        and schema is not None
+        and SPATIAL_INDEX_COLUMN in schema.names
+        and SPATIAL_INDEX_COLUMN not in columns
+    ):
+        columns = columns + [SPATIAL_INDEX_COLUMN]
+    dataframe = file_io.read_parquet_file_to_pandas(
+        hc.io.pixel_catalog_file(catalog.catalog_base_dir, pixel, query_url_params),
+        columns=columns,
+        schema=schema,
+        **kwargs,
+    )
+
+    if dataframe.index.name != SPATIAL_INDEX_COLUMN and SPATIAL_INDEX_COLUMN in dataframe.columns:
+        dataframe = dataframe.set_index(SPATIAL_INDEX_COLUMN)
+
+    return dataframe
diff --git a/tests/conftest.py b/tests/conftest.py
index 241844bc..8f90bd6d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -113,7 +113,7 @@ def small_sky_to_order1_source_soft_dir(test_data_dir):

 @pytest.fixture
 def small_sky_hats_catalog(small_sky_dir):
-    return hc.catalog.Catalog.read_hats(small_sky_dir)
+    return hc.read_hats(small_sky_dir)


 @pytest.fixture
@@ -123,7 +123,7 @@ def small_sky_order1_id_index_dir(test_data_dir):

 @pytest.fixture
 def small_sky_catalog(small_sky_dir):
-    return lsdb.read_hats(small_sky_dir, catalog_type=lsdb.catalog.Catalog)
+    return lsdb.read_hats(small_sky_dir)


 @pytest.fixture
@@ -158,7 +158,7 @@ def small_sky_to_xmatch_soft_catalog(small_sky_to_xmatch_soft_dir):

 @pytest.fixture
 def small_sky_order1_hats_catalog(small_sky_order1_dir):
-    return hc.catalog.Catalog.read_hats(small_sky_order1_dir)
+    return hc.read_hats(small_sky_order1_dir)


 @pytest.fixture
diff --git a/tests/lsdb/catalog/test_index_search.py b/tests/lsdb/catalog/test_index_search.py
index eb6ce357..9fbd2493 100644
--- a/tests/lsdb/catalog/test_index_search.py
+++ b/tests/lsdb/catalog/test_index_search.py
@@ -1,10 +1,10 @@
 import nested_dask as nd
 import nested_pandas as npd
-from hats.catalog.index.index_catalog import IndexCatalog
+from hats import read_hats


 def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, assert_divisions_are_correct):
-    catalog_index = IndexCatalog.read_hats(small_sky_order1_id_index_dir)
+    catalog_index = read_hats(small_sky_order1_id_index_dir)
     # Searching for an object that does not exist
     index_search_catalog = small_sky_order1_catalog.index_search([900], catalog_index)
     assert isinstance(index_search_catalog._ddf, nd.NestedFrame)
@@ -20,7 +20,7 @@ def test_index_search(small_sky_order1_catalog, small_sky_order1_id_index_dir, a


 def test_index_search_coarse_versus_fine(small_sky_order1_catalog, small_sky_order1_id_index_dir):
-    catalog_index = IndexCatalog.read_hats(small_sky_order1_id_index_dir)
+    catalog_index = read_hats(small_sky_order1_id_index_dir)
     coarse_index_search = small_sky_order1_catalog.index_search([700], catalog_index, fine=False)
     fine_index_search = small_sky_order1_catalog.index_search([700], catalog_index)
     assert coarse_index_search.get_healpix_pixels() == fine_index_search.get_healpix_pixels()
diff --git a/tests/lsdb/catalog/test_margin_catalog.py b/tests/lsdb/catalog/test_margin_catalog.py
index ff4583da..f35fd46e 100644
--- a/tests/lsdb/catalog/test_margin_catalog.py
+++ b/tests/lsdb/catalog/test_margin_catalog.py
@@ -12,7 +12,7 @@ def test_read_margin_catalog(small_sky_xmatch_margin_dir):
     margin = lsdb.read_hats(small_sky_xmatch_margin_dir)
     assert isinstance(margin, MarginCatalog)
     assert isinstance(margin._ddf, nd.NestedFrame)
-    hc_margin = hc.catalog.MarginCatalog.read_hats(small_sky_xmatch_margin_dir)
+    hc_margin = hc.read_hats(small_sky_xmatch_margin_dir)
     assert margin.hc_structure.catalog_info == hc_margin.catalog_info
     assert margin.hc_structure.get_healpix_pixels() == hc_margin.get_healpix_pixels()
     assert margin.get_healpix_pixels() == margin.hc_structure.get_healpix_pixels()
diff --git a/tests/lsdb/loaders/hats/test_read_hats.py b/tests/lsdb/loaders/hats/test_read_hats.py
index ff8eb542..b1c77072 100644
--- a/tests/lsdb/loaders/hats/test_read_hats.py
+++ b/tests/lsdb/loaders/hats/test_read_hats.py
@@ -7,7 +7,6 @@
 import numpy.testing as npt
 import pandas as pd
 import pytest
-from hats.catalog.index.index_catalog import IndexCatalog
 from hats.pixel_math import HealpixPixel
 from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN, compute_spatial_index
 from pandas.core.dtypes.base import ExtensionDtype
@@ -49,6 +48,26 @@ def test_read_hats_no_pandas(small_sky_order1_no_pandas_dir, assert_divisions_ar
     assert_index_correct(catalog)


+def test_read_hats_with_margin_extra_kwargs(small_sky_xmatch_dir, small_sky_xmatch_margin_dir):
+    catalog = lsdb.read_hats(
+        small_sky_xmatch_dir,
+        margin_cache=small_sky_xmatch_margin_dir,
+        columns=["ra", "dec"],
+        filters=[("ra", ">", 300)],
+        engine="pyarrow",
+    )
+    assert isinstance(catalog, lsdb.Catalog)
+    filtered_cat = catalog.compute()
+    assert all(catalog.columns == ["ra", "dec"])
+    assert np.all(filtered_cat["ra"] > 300)
+
+    margin = catalog.margin
+    assert isinstance(margin, lsdb.MarginCatalog)
+    filtered_margin = margin.compute()
+    assert all(margin.columns == ["ra", "dec"])
+    assert np.all(filtered_margin["ra"] > 300)
+
+
 def test_read_hats_with_columns(small_sky_order1_dir):
     filter_columns = ["ra", "dec"]
     catalog = lsdb.read_hats(small_sky_order1_dir, columns=filter_columns)
@@ -106,7 +125,7 @@ def test_parquet_data_in_partitions_match_files(small_sky_order1_dir, small_sky_


 def test_read_hats_specify_catalog_type(small_sky_catalog, small_sky_dir):
-    catalog = lsdb.read_hats(small_sky_dir, catalog_type=lsdb.Catalog)
+    catalog = lsdb.read_hats(small_sky_dir)
     assert isinstance(catalog, lsdb.Catalog)
     assert isinstance(catalog._ddf, nd.NestedFrame)
     pd.testing.assert_frame_equal(catalog.compute(), small_sky_catalog.compute())
@@ -115,11 +134,6 @@ def test_read_hats_specify_catalog_type(small_sky_catalog, small_sky_dir):
     assert isinstance(catalog.compute(), npd.NestedFrame)


-def test_read_hats_specify_wrong_catalog_type(small_sky_dir):
-    with pytest.raises(ValueError):
-        lsdb.read_hats(small_sky_dir, catalog_type=int)
-
-
 def test_catalog_with_margin_object(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog):
     catalog = lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog)
     assert isinstance(catalog, lsdb.Catalog)
@@ -193,7 +207,7 @@ def test_read_hats_subset_with_index_search(
     small_sky_order1_catalog,
     small_sky_order1_id_index_dir,
 ):
-    catalog_index = IndexCatalog.read_hats(small_sky_order1_id_index_dir)
+    catalog_index = hc.read_hats(small_sky_order1_id_index_dir)
     # Filtering using catalog's index_search
     index_search_catalog = small_sky_order1_catalog.index_search([700], catalog_index)
     # Filtering when calling `read_hats`
@@ -217,7 +231,7 @@ def test_read_hats_subset_with_order_search(small_sky_source_catalog, small_sky_

 def test_read_hats_subset_no_partitions(small_sky_order1_dir, small_sky_order1_id_index_dir):
     with pytest.raises(ValueError, match="no coverage"):
-        catalog_index = IndexCatalog.read_hats(small_sky_order1_id_index_dir)
+        catalog_index = hc.read_hats(small_sky_order1_id_index_dir)
         index_search = IndexSearch([900], catalog_index)
         lsdb.read_hats(small_sky_order1_dir, search_filter=index_search)
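
Below is a minimal usage sketch of the simplified `read_hats` API introduced by this diff, for trying the change locally. The catalog paths and cone-search values are hypothetical placeholders; `margin_cache`, `columns`, `search_filter`, and the extra parquet kwargs behave as exercised by the tests above.

import lsdb

# The catalog type is now inferred from the catalog metadata on disk,
# so no `catalog_type` argument is passed (the parameter was removed).
catalog = lsdb.read_hats(
    "./my_catalog_dir",  # hypothetical catalog path
    columns=["ra", "dec"],
    margin_cache="./my_margin_catalog_dir",  # hypothetical margin path
    search_filter=lsdb.core.search.ConeSearch(315.0, -66.4, 3600),
)

# The margin catalog is loaded with the same columns and search filter,
# as asserted in test_read_hats_with_margin_extra_kwargs.
print(type(catalog), type(catalog.margin))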