diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index bb21a56..0b5238a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.10", "3.13"] + python-version: ["3.11", "3.13"] steps: - name: Checkout uses: actions/checkout@v6 diff --git a/doc/api.rst b/doc/api.rst index 97fa1d7..419b379 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -5,11 +5,32 @@ API Reference This page provides an auto-generated summary of Xoak's API. +.. currentmodule:: xoak + +Xarray NDPointIndex tree adapters +--------------------------------- + +The following classes may be used with :py:class:`xarray.indexes.NDPointIndex`, they can +be passed as ``tree_adapter_cls`` option value via :py:meth:`xarray.Dataset.set_xindex` or +:py:meth:`xarray.DataArray.set_xindex`. + +.. autosummary:: + :toctree: _api_generated/ + + S2PointTreeAdapter + SklearnBallTreeAdapter + SklearnGeoBallTreeAdapter + SklearnKDTreeAdapter + .. currentmodule:: xarray Dataset.xoak ------------ +.. warning:: + + This API is deprecated and will be removed in a future version of Xoak. + This accessor extends :py:class:`xarray.Dataset` with all the methods and properties listed below. Proper use of this accessor should be like: @@ -40,6 +61,10 @@ properties listed below. Proper use of this accessor should be like: DataArray.xoak -------------- +.. warning:: + + This API is deprecated and will be removed in a future version of Xoak. + The accessor above is also registered for :py:class:`xarray.DataArray`. **Properties** @@ -58,39 +83,3 @@ The accessor above is also registered for :py:class:`xarray.DataArray`. DataArray.xoak.set_index DataArray.xoak.sel - -Indexes -------- - -.. currentmodule:: xoak - -.. autosummary:: - :toctree: _api_generated/ - - IndexAdapter - IndexRegistry - -**Xoak's built-in index adapters** - -.. 
currentmodule:: xoak.index.scipy_adapters - -.. autosummary:: - :toctree: _api_generated/ - - ScipyKDTreeAdapter - -.. currentmodule:: xoak.index.sklearn_adapters - -.. autosummary:: - :toctree: _api_generated/ - - SklearnKDTreeAdapter - SklearnBallTreeAdapter - SklearnGeoBallTreeAdapter - -.. currentmodule:: xoak.index.s2_adapters - -.. autosummary:: - :toctree: _api_generated/ - - S2PointIndexAdapter diff --git a/doc/environment.yml b/doc/environment.yml index 4ecbbf0..7186d18 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -10,7 +10,7 @@ dependencies: - numpy - pys2index - scikit-learn - - sphinx + - sphinx=8.2.3 - sphinx-autosummary-accessors - pydata-sphinx-theme=0.15.4 - sphinx-book-theme=1.1.4 diff --git a/doc/examples/custom_indexes.ipynb b/doc/examples/custom_indexes.ipynb index f5fd1d4..7d13de3 100644 --- a/doc/examples/custom_indexes.ipynb +++ b/doc/examples/custom_indexes.ipynb @@ -6,7 +6,7 @@ "source": [ "# Custom Indexes\n", "\n", - "While Xoak provides some built-in index adapters, it is easy to adapt and register new indexes. " + "Xoak provides some built-in adapters for [xarray.indexes.NDPointIndex](https://docs.xarray.dev/en/stable/generated/xarray.indexes.NDPointIndex.html) ; it is easy to create custom ones." ] }, { @@ -16,26 +16,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "import xarray as xr\n", - "import xoak" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "An instance of `xoak.IndexRegistry` by default contains a collection of Xoak built-in index adapters:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ireg = xoak.IndexRegistry()\n", - "\n", - "ireg" + "import xarray as xr" ] }, { @@ -44,9 +25,7 @@ "source": [ "## Example: add a brute-force \"index\"\n", "\n", - "Every Xoak supported index is a subclass of `xoak.IndexAdapter` that must implement the `build` and `query` methods. 
The `IndexRegistry.register` decorator may be used to register a new index adpater.\n", - "\n", - "Let's create and register a new adapter, which simply performs brute-force nearest-neighbor lookup by computing the pairwise distances between all index and query points and finding the minimum distance. " + "This example adapter simply performs brute-force nearest-neighbor lookup by computing the pairwise distances between all index and query points and finding the minimum distance. " ] }, { @@ -55,43 +34,32 @@ "metadata": {}, "outputs": [], "source": [ + "from collections.abc import Mapping\n", + "from typing import Any\n", + "\n", "from sklearn.metrics.pairwise import pairwise_distances_argmin_min\n", + "from xarray.indexes.nd_point_index import TreeAdapter\n", "\n", "\n", - "@ireg.register('brute_force')\n", - "class BruteForceIndex(xoak.IndexAdapter):\n", + "class BruteForceTreeAdapter(TreeAdapter):\n", " \"\"\"Brute-force nearest neighbor lookup.\"\"\"\n", " \n", - " def build(self, points):\n", - " # there is no index to build here, just return the points\n", - " return points\n", - " \n", - " def query(self, index, points):\n", - " positions, distances = pairwise_distances_argmin_min(points, index)\n", - " return distances, positions\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This new index now appears in the registry:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ireg" + " def __init__(self, points: np.ndarray, options: Mapping[str, Any]):\n", + " self._index_points = points\n", + "\n", + " def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:\n", + " positions, distances = pairwise_distances_argmin_min(points, self._index_points)\n", + " return distances, positions\n", + "\n", + " def equals(self, other: \"BruteForceTreeAdapter\") -> bool:\n", + " return np.array_equal(self._index_points, other._index_points)\n" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ - "Let's use this index in the basic example below:" + "Let's use this adapter in the basic example below:" ] }, { @@ -113,7 +81,11 @@ ")\n", "\n", "# set the brute-force index (doesn't really build any index in this case)\n", - "ds_mesh.xoak.set_index(['meshx', 'meshy'], ireg.brute_force)\n", + "ds_mesh = ds_mesh.set_xindex(\n", + " ['meshx', 'meshy'],\n", + " xr.indexes.NDPointIndex,\n", + " tree_adapter_cls=BruteForceTreeAdapter,\n", + ")\n", "\n", "# create trajectory points\n", "ds_trajectory = xr.Dataset({\n", @@ -122,29 +94,23 @@ "})\n", "\n", "# select mesh points\n", - "ds_selection = ds_mesh.xoak.sel(\n", + "ds_selection = ds_mesh.sel(\n", " meshx=ds_trajectory.trajx,\n", - " meshy=ds_trajectory.trajy\n", + " meshy=ds_trajectory.trajy,\n", + " method=\"nearest\",\n", ")\n", "\n", "# plot results\n", "ds_trajectory.plot.scatter(x='trajx', y='trajy', c='k', alpha=0.7);\n", "ds_selection.plot.scatter(x='meshx', y='meshy', hue='field', alpha=0.9);" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:xoak_dev]", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-env-xoak_dev-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -156,7 +122,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/doc/examples/dask_support.ipynb b/doc/examples/dask_support.ipynb deleted file mode 100644 index bb36bcf..0000000 --- a/doc/examples/dask_support.ipynb +++ /dev/null @@ -1,150 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Indexing / Selecting Large Data\n", - "\n", - "**Note: this feature is experimental!**\n", - "\n", - "When the dataset have chunked coordinates (dask arrays), Xoak may build the index and/or 
performs the selection in-parallel. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import dask\n", - "import dask.array as da\n", - "import numpy as np\n", - "import xarray as xr\n", - "import xoak\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's first create an `xarray.Dataset` of latitude / longitude points located randomly on the sphere, forming a 2-dimensional (x, y) model mesh. The array coordinates are split into 4 chunks of 250x250 elements each." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "shape = (500, 500)\n", - "chunk_shape = (250, 250)\n", - "\n", - "lat = da.random.uniform(-90, 90, size=shape, chunks=chunk_shape)\n", - "lon = da.random.uniform(-180, 180, size=shape, chunks=chunk_shape)\n", - "\n", - "field = lat + lon\n", - "\n", - "ds_mesh = xr.Dataset(\n", - " coords={'lat': (('x', 'y'), lat), 'lon': (('x', 'y'), lon)},\n", - " data_vars={'field': (('x', 'y'), field)},\n", - ")\n", - "\n", - "ds_mesh" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Xoak builds an index structure for each chunk (all coordinates must have the same chunks):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ds_mesh.xoak.set_index(['lat', 'lon'], 'sklearn_geo_balltree')\n", - "\n", - "# here returns a list of BallTree objects, one for each chunk\n", - "ds_mesh.xoak.index" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create some query data points, which may also be chunked (here 2 chunks)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "shape = (100, 10)\n", - "chunk_shape = (50, 10)\n", - "\n", - "ds_data = xr.Dataset({\n", - " 'lat': (('y', 'x'), da.random.uniform(-90, 90, size=shape, chunks=chunk_shape)),\n", - " 'lon': (('y', 'x'), da.random.uniform(-180, 180, size=shape, chunks=chunk_shape))\n", - "})\n", - "\n", - "ds_data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Queries may be perfomed in parallel using Dask. Please note, however, that some indexes may not support multi-threads and/or multi-process parallelism." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from dask.diagnostics import ProgressBar\n", - "\n", - "with ProgressBar(), dask.config.set(scheduler='processes'):\n", - " ds_selection = ds_mesh.xoak.sel(lat=ds_data.lat, lon=ds_data.lon)\n", - "\n", - "ds_selection" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:xoak_dev]", - "language": "python", - "name": "conda-env-xoak_dev-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/doc/examples/index.rst b/doc/examples/index.rst index bc48c62..f872780 100644 --- a/doc/examples/index.rst +++ b/doc/examples/index.rst @@ -5,5 +5,4 @@ Examples :maxdepth: 2 introduction - dask_support custom_indexes diff --git a/doc/examples/introduction.ipynb b/doc/examples/introduction.ipynb index a0e9b63..b5e3e71 100644 --- a/doc/examples/introduction.ipynb +++ b/doc/examples/introduction.ipynb @@ -6,7 +6,9 @@ "source": [ "# 
Introduction to Xoak\n", "\n", - "This notebook briefly shows how to use Xoak with Xarray's [advanced indexing](http://xarray.pydata.org/en/stable/indexing.html#more-advanced-indexing) to perform point-wise selection of irrelgularly spaced data encoded in coordinates with an arbitrary number of dimensions (1, 2, ..., n-d)." + "This notebook briefly shows how to use Xoak with Xarray's [advanced indexing](http://xarray.pydata.org/en/stable/indexing.html#more-advanced-indexing) to perform point-wise selection of irregularly spaced data encoded in coordinates with an arbitrary number of dimensions (1, 2, ..., n-d).\n", + "\n", + "**Note**: Xoak relies on [xarray.indexes.NDPointIndex](https://docs.xarray.dev/en/stable/generated/xarray.indexes.NDPointIndex.html), which was added in Xarray version 2025.07.1." ] }, { @@ -15,10 +17,12 @@ "metadata": {}, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import xarray as xr\n", "import xoak\n", - "\n" + "\n", + "xr.set_options(display_expand_indexes=True);" ] }, { @@ -59,7 +63,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We first need to build an index to allow fast point-wise selection. Xoak supports several indexes that can be used depending on the data. Here we use the `sklearn_geo_balltree` index, a wrapper around [sklearn.BallTree](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html#sklearn.neighbors.BallTree) using the Haversine distance metric that is suited for indexing latitude / longitude points.\n", + "We first need to build an index to allow fast point-wise selection. Xoak extends [xarray.indexes.NDPointIndex](https://docs.xarray.dev/en/stable/generated/xarray.indexes.NDPointIndex.html) with different structures available for use cases depending on the data. 
In the example below we use a wrapper around [sklearn.BallTree](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html#sklearn.neighbors.BallTree) using the Haversine distance metric that is suited for indexing latitude / longitude points.\n", "\n", "With this index, it is important to specify `lat` and `lon` in this specific order. Both the `lat` and `lon` coordinates must have exactly the same dimensions in the same order, here `('x', 'y')`." ] @@ -70,7 +74,13 @@ "metadata": {}, "outputs": [], "source": [ - "ds_mesh.xoak.set_index(['lat', 'lon'], 'sklearn_geo_balltree')" + "ds_mesh = ds_mesh.set_xindex(\n", + " ['lat', 'lon'],\n", + " xr.indexes.NDPointIndex,\n", + " tree_adapter_cls=xoak.SklearnGeoBallTreeAdapter\n", + ")\n", + "\n", + "ds_mesh" ] }, { @@ -98,9 +108,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can use `xarray.Dataset.xoak.sel()` to select the mesh points that are the closest to the trajectory vertices. It works very much like `xarray.Dataset.sel()` and returns another Dataset with the selection.\n", - "\n", - "Like for `xarray.Dataset.xoak.set_index()`, it is important here that all indexer coordinates (`latitude` and `longitude` in this example) have the exact same dimensions (here `'trajectory'`). Indexers must be given for each coordinate used to build the index above, (here `latitude` for `lat` and `longitude` for `lon`). " + "We can now simply use `xarray.Dataset.sel()` to select the mesh points that are the closest to the trajectory vertices. All indexer coordinates (`latitude` and `longitude` in this example) must have the exact same dimensions (here `'trajectory'`). Indexers must be given for each coordinate used to build the index above, (here `latitude` for `lat` and `longitude` for `lon`). 
" ] }, { @@ -109,9 +117,10 @@ "metadata": {}, "outputs": [], "source": [ - "ds_selection = ds_mesh.xoak.sel(\n", + "ds_selection = ds_mesh.sel(\n", " lat=ds_trajectory.latitude,\n", - " lon=ds_trajectory.longitude\n", + " lon=ds_trajectory.longitude,\n", + " method=\"nearest\",\n", ")\n", "\n", "ds_selection" @@ -130,15 +139,16 @@ "metadata": {}, "outputs": [], "source": [ - "ds_trajectory.plot.scatter(x='longitude', y='latitude', c='k', alpha=0.7);\n", - "ds_selection.plot.scatter(x='lon', y='lat', hue='field', alpha=0.9);" + "fig, ax = plt.subplots()\n", + "ds_trajectory.plot.scatter(ax=ax, x='longitude', y='latitude', c='k', alpha=0.7)\n", + "ds_selection.plot.scatter(ax=ax, x='lon', y='lat', hue='field', alpha=0.9);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Xoak also supports providing coordinates with an arbitrary number of dimensions as indexers, like in the example below with vertices of another mesh on the sphere. " + "Xarray also supports providing coordinates with an arbitrary number of dimensions as indexers, like in the example below with vertices of another mesh on the sphere. 
" ] }, { @@ -152,9 +162,10 @@ " 'longitude': (('x', 'y'), np.random.uniform(-180, 180, size=(10, 10)))\n", "})\n", "\n", - "ds_selection = ds_mesh.xoak.sel(\n", + "ds_selection = ds_mesh.sel(\n", " lat=ds_mesh2.latitude,\n", - " lon=ds_mesh2.longitude\n", + " lon=ds_mesh2.longitude,\n", + " method=\"nearest\",\n", ")\n", "\n", "ds_selection" @@ -166,8 +177,9 @@ "metadata": {}, "outputs": [], "source": [ - "ds_mesh2.plot.scatter(x='longitude', y='latitude', c='k', alpha=0.7);\n", - "ds_selection.plot.scatter(x='lon', y='lat', hue='field', alpha=0.9);" + "fig, ax = plt.subplots()\n", + "ds_mesh2.plot.scatter(ax=ax, x='longitude', y='latitude', c='k', alpha=0.7)\n", + "ds_selection.plot.scatter(ax=ax, x='lon', y='lat', hue='field', alpha=0.9);" ] }, { @@ -180,9 +192,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:xoak_dev]", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-env-xoak_dev-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -194,7 +206,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/doc/release_notes.rst b/doc/release_notes.rst index e7e2689..35e0c61 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -3,6 +3,28 @@ Release Notes ============= +v0.2.0 (Unreleased) +------------------- + +Features +~~~~~~~~ + +- Xoak now relies on :py:class:`xarray.indexes.NDPointIndex` for point-wise + indexing of irregular data, by providing custom ``TreeAdapter`` classes. + The current functionality remains the same. See documentation examples + for more details (:pull:`44`). + +Deprecations +~~~~~~~~~~~~ + +- Xoak specific API :py:meth:`xarray.Dataset.xoak.set_index` and + :py:meth:`xarray.Dataset.xoak.sel` has been deprecated in favor of Xarray's + API :py:meth:`xarray.Dataset.set_xindex` and :py:meth:`xarray.Dataset.sel`. 
+ See documentation examples for more details (:pull:`44`). +- Xoak experimental support for chunked coordinates (Dask arrays) has been + deprecated (:pull:`44`). + + v0.1.2 (20 November 2025) ------------------------- diff --git a/pyproject.toml b/pyproject.toml index 25edad7..fd74a05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - "xarray", + "xarray>=2025.07.1", "numpy", "scipy", "dask", diff --git a/src/xoak/__init__.py b/src/xoak/__init__.py index 6fac9d2..c28a293 100644 --- a/src/xoak/__init__.py +++ b/src/xoak/__init__.py @@ -1,12 +1,22 @@ from importlib.metadata import version -from .accessor import XoakAccessor -from .index import IndexAdapter, IndexRegistry +from xoak.accessor import XoakAccessor +from xoak.index import IndexAdapter, IndexRegistry +from xoak.tree_adapters import ( + S2PointTreeAdapter, + SklearnBallTreeAdapter, + SklearnGeoBallTreeAdapter, + SklearnKDTreeAdapter, +) __all__ = [ - "XoakAccessor", "IndexAdapter", "IndexRegistry", + "SklearnBallTreeAdapter", + "SklearnGeoBallTreeAdapter", + "SklearnKDTreeAdapter", + "S2PointTreeAdapter", + "XoakAccessor", ] __version__ = version("xoak") diff --git a/src/xoak/accessor.py b/src/xoak/accessor.py index d3a1f1e..0ac1218 100644 --- a/src/xoak/accessor.py +++ b/src/xoak/accessor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections.abc import Hashable, Iterable, Mapping from typing import Any @@ -87,6 +88,12 @@ def set_index( If the given coordinates are chunked (Dask arrays), this method will (lazily) create a forest of index trees (one tree per chunk of the flattened coordinate arrays). + .. warning:: + This method has been deprecated. Please use the Xarray API instead, e.g., + ``ds.set_xindex([...], xarray.indexes.NDPointIndex, tree_adapter_cls=...)``. + + Support for chunked Dask coordinates has been deprecated as well. 
+ Parameters ---------- coords : iterable @@ -118,8 +125,22 @@ def set_index( X = coords_to_point_array([self._xarray_obj[c] for c in coords]) if isinstance(X, np.ndarray): + warnings.warn( + "Setting the index via the xoak accessor is deprecated and will be removed " + f"in a future version. Instead of `.xoak.set_index({coords!r}, ...)`, " + f"use the Xarray API `.set_xindex({coords!r}, xarray.indexes.NDPointIndex, " + "tree_adapter_cls=...)`", + FutureWarning, + stacklevel=2, + ) self._index = XoakIndexWrapper(self._index_type, X, 0, **kwargs) else: + warnings.warn( + "Setting a lazy index from chunked coordinates is a deprecated experimental " + "feature and will be removed in a future version.", + FutureWarning, + stacklevel=2, + ) self._index = self._build_index_forest_delayed(X, persist=persist, **kwargs) @property @@ -237,6 +258,11 @@ def sel( ) -> xr.Dataset | xr.DataArray: """Selection based on a ball tree index. + .. warning:: + This method has been deprecated. Please use the Xarray API instead + ``ds.sel(...)`` after setting an ``xarray.indexes.NDPointIndex`` with one + of the tree adapter classes available in Xoak. + The index must have been already built using `xoak.set_index()`. It behaves mostly like :meth:`xarray.Dataset.sel` and @@ -252,6 +278,15 @@ def sel( coordinates are chunked. """ + warnings.warn( + "Data selection via `.xoak.sel()` is deprecated and will be removed in a future " + "version. Instead of `.xoak.sel(...)`, use directly the Xarray API `.sel(...)` " + "after setting an `xarray.indexes.NDPointIndex` with one of the tree adapter classes " + "available in Xoak.", + FutureWarning, + stacklevel=2, + ) + if not getattr(self, "_index", False): raise ValueError( "The index(es) has/have not been built yet. 
Call `.xoak.set_index()` first" diff --git a/src/xoak/tests/test_accessor.py b/src/xoak/tests/test_accessor.py index 4c55eb8..92a9c50 100644 --- a/src/xoak/tests/test_accessor.py +++ b/src/xoak/tests/test_accessor.py @@ -5,6 +5,30 @@ import xoak # noqa: F401 +def test_deprecation_warnings() -> None: + ds = xr.Dataset( + coords={ + "x": (("a", "b"), [[0, 1], [2, 3]]), + "y": (("a", "b"), [[0, 1], [2, 3]]), + } + ) + indexer = xr.Dataset( + coords={ + "x": ("p", [1.2, 2.9]), + "y": ("p", [1.2, 2.9]), + } + ) + + with pytest.warns(FutureWarning): + ds.xoak.set_index(["x", "y"], "scipy_kdtree") + + with pytest.warns(FutureWarning): + ds.chunk().xoak.set_index(["x", "y"], "scipy_kdtree") + + with pytest.warns(FutureWarning): + ds.xoak.sel(x=indexer.x, y=indexer.y) + + def test_set_index_error(): ds = xr.Dataset( coords={ diff --git a/src/xoak/tests/test_s2_adapters.py b/src/xoak/tests/test_s2_adapters.py index 523e0cc..f0e3fc1 100644 --- a/src/xoak/tests/test_s2_adapters.py +++ b/src/xoak/tests/test_s2_adapters.py @@ -16,6 +16,31 @@ def test_s2point(geo_dataset, geo_indexer, geo_expected): xr.testing.assert_equal(ds_sel.load(), geo_expected.load()) +def test_ndpointindex_s2point( + geo_dataset, geo_indexer, geo_expected, dataset_array_lib, indexer_array_lib +): + # TODO: remove when refactoring fixtures without dask + if dataset_array_lib is not np or indexer_array_lib is not np: + pytest.skip() + + geo_dataset = geo_dataset.compute() + geo_indexer = geo_indexer.compute() + + # TODO: remove when https://github.com/pydata/xarray/issues/10940 is fixed + if geo_dataset.lat.ndim != 2: + pytest.skip() + + geo_dataset = geo_dataset.set_xindex( + ["lat", "lon"], xr.indexes.NDPointIndex, tree_adapter_cls=xoak.S2PointTreeAdapter + ) + ds_sel = geo_dataset.sel(lat=geo_indexer.latitude, lon=geo_indexer.longitude, method="nearest") + + xr.testing.assert_equal(ds_sel, geo_expected) + + # NDPointIndex.equals() should return True (via merge) + 
xr.testing.assert_identical(xr.merge([geo_dataset, geo_dataset]), geo_dataset) + + def test_s2point_sizeof(): ds = xr.Dataset(coords={"lat": ("points", [0.0, 10.0]), "lon": ("points", [-5.0, 5.0])}) points = np.array([[0.0, -5.0], [10.0, 5.0]]) diff --git a/src/xoak/tests/test_sklearn_adapters.py b/src/xoak/tests/test_sklearn_adapters.py index 87c7a6b..b16a456 100644 --- a/src/xoak/tests/test_sklearn_adapters.py +++ b/src/xoak/tests/test_sklearn_adapters.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import xarray as xr @@ -13,6 +14,32 @@ def test_sklearn_kdtree(xyz_dataset, xyz_indexer, xyz_expected): xr.testing.assert_equal(ds_sel.load(), xyz_expected.load()) +def test_ndpointindex_kdtree( + xyz_dataset, xyz_indexer, xyz_expected, dataset_array_lib, indexer_array_lib +): + # TODO: remove when refactoring fixtures without dask + if dataset_array_lib is not np or indexer_array_lib is not np: + pytest.skip() + + xyz_dataset = xyz_dataset.compute() + xyz_indexer = xyz_indexer.compute() + xyz_expected = xyz_expected.compute() + + # TODO: remove when https://github.com/pydata/xarray/issues/10940 is fixed + if xyz_dataset.x.ndim != 3: + pytest.skip() + + xyz_dataset = xyz_dataset.set_xindex( + ["x", "y", "z"], xr.indexes.NDPointIndex, tree_adapter_cls=xoak.SklearnKDTreeAdapter + ) + ds_sel = xyz_dataset.sel(x=xyz_indexer.xx, y=xyz_indexer.yy, z=xyz_indexer.zz, method="nearest") + + xr.testing.assert_equal(ds_sel, xyz_expected) + + # NDPointIndex.equals() should return True (via merge) + xr.testing.assert_identical(xr.merge([xyz_dataset, xyz_dataset]), xyz_dataset) + + def test_sklearn_kdtree_options(): ds = xr.Dataset(coords={"x": ("points", [1, 2]), "y": ("points", [1, 2])}) @@ -29,6 +56,32 @@ def test_sklearn_balltree(xyz_dataset, xyz_indexer, xyz_expected): xr.testing.assert_equal(ds_sel.load(), xyz_expected.load()) +def test_ndpointindex_balltree( + xyz_dataset, xyz_indexer, xyz_expected, dataset_array_lib, indexer_array_lib +): + # TODO: remove when 
refactoring fixtures without dask + if dataset_array_lib is not np or indexer_array_lib is not np: + pytest.skip() + + xyz_dataset = xyz_dataset.compute() + xyz_indexer = xyz_indexer.compute() + xyz_expected = xyz_expected.compute() + + # TODO: remove when https://github.com/pydata/xarray/issues/10940 is fixed + if xyz_dataset.x.ndim != 3: + pytest.skip() + + xyz_dataset = xyz_dataset.set_xindex( + ["x", "y", "z"], xr.indexes.NDPointIndex, tree_adapter_cls=xoak.SklearnBallTreeAdapter + ) + ds_sel = xyz_dataset.sel(x=xyz_indexer.xx, y=xyz_indexer.yy, z=xyz_indexer.zz, method="nearest") + + xr.testing.assert_equal(ds_sel, xyz_expected) + + # NDPointIndex.equals() should return True (via merge) + xr.testing.assert_identical(xr.merge([xyz_dataset, xyz_dataset]), xyz_dataset) + + def test_sklearn_balltree_options(): ds = xr.Dataset(coords={"x": ("points", [1, 2]), "y": ("points", [1, 2])}) @@ -45,6 +98,32 @@ def test_sklearn_geo_balltree(geo_dataset, geo_indexer, geo_expected): xr.testing.assert_equal(ds_sel.load(), geo_expected.load()) +def test_ndpointindex_geo_balltree( + geo_dataset, geo_indexer, geo_expected, dataset_array_lib, indexer_array_lib +): + # TODO: remove when refactoring fixtures without dask + if dataset_array_lib is not np or indexer_array_lib is not np: + pytest.skip() + + geo_dataset = geo_dataset.compute() + geo_indexer = geo_indexer.compute() + geo_expected = geo_expected.compute() + + # TODO: remove when https://github.com/pydata/xarray/issues/10940 is fixed + if geo_dataset.lat.ndim != 2: + pytest.skip() + + geo_dataset = geo_dataset.set_xindex( + ["lat", "lon"], xr.indexes.NDPointIndex, tree_adapter_cls=xoak.SklearnGeoBallTreeAdapter + ) + ds_sel = geo_dataset.sel(lat=geo_indexer.latitude, lon=geo_indexer.longitude, method="nearest") + + xr.testing.assert_equal(ds_sel, geo_expected) + + # NDPointIndex.equals() should return True (via merge) + xr.testing.assert_identical(xr.merge([geo_dataset, geo_dataset]), geo_dataset) + + def 
test_sklearn_geo_balltree_options(): ds = xr.Dataset(coords={"x": ("points", [1, 2]), "y": ("points", [1, 2])}) diff --git a/src/xoak/tree_adapters.py b/src/xoak/tree_adapters.py new file mode 100644 index 0000000..8964c83 --- /dev/null +++ b/src/xoak/tree_adapters.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +import numpy as np + +try: + from xarray.indexes.nd_point_index import TreeAdapter # type: ignore +except ImportError: + + class TreeAdapter: ... + + +if TYPE_CHECKING: + import pys2index + import sklearn.neighbors + + +class S2PointTreeAdapter(TreeAdapter): + """:py:class:`pys2index.S2PointIndex` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" + + _s2point_index: pys2index.S2PointIndex + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from pys2index import S2PointIndex + + self._s2point_index = S2PointIndex(points) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._s2point_index.query(points) + + def equals(self, other: S2PointTreeAdapter) -> bool: + return np.array_equal( + self._s2point_index.get_cell_ids(), other._s2point_index.get_cell_ids() + ) + + +class SklearnKDTreeAdapter(TreeAdapter): + """:py:class:`sklearn.neighbors.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" + + _kdtree: sklearn.neighbors.KDTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from sklearn.neighbors import KDTree + + self._kdtree = KDTree(points, **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._kdtree.query(points) + + def equals(self, other: SklearnKDTreeAdapter) -> bool: + return np.array_equal(self._kdtree.data, other._kdtree.data) + + +class SklearnBallTreeAdapter(TreeAdapter): + """:py:class:`sklearn.neighbors.BallTree` adapter for + :py:class:`~xarray.indexes.NDPointIndex`. 
+ + """ + + _balltree: sklearn.neighbors.BallTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from sklearn.neighbors import BallTree + + self._balltree = BallTree(points, **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._balltree.query(points) + + def equals(self, other: SklearnBallTreeAdapter) -> bool: + return np.array_equal(self._balltree.data, other._balltree.data) + + +class SklearnGeoBallTreeAdapter(TreeAdapter): + """:py:class:`sklearn.neighbors.BallTree` adapter for + :py:class:`~xarray.indexes.NDPointIndex`, using the 'haversine' metric. + + It can be used for indexing a set of latitude / longitude points. + + When building the index, the coordinates must be given in the latitude, + longitude order. + + Latitude and longitude values must be given in degrees for both index and + query points (those values are converted to radians by this adapter). + + """ + + _balltree: sklearn.neighbors.BallTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from sklearn.neighbors import BallTree + + opts = dict(options) + opts.update({"metric": "haversine"}) + + self._balltree = BallTree(np.deg2rad(points), **opts) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._balltree.query(np.deg2rad(points)) + + def equals(self, other: SklearnGeoBallTreeAdapter) -> bool: + return np.array_equal(self._balltree.data, other._balltree.data)