diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d2f743b..b4107dac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ on: jobs: test: - name: py${{ matrix.versions.python-version }} ${{ matrix.versions.resolution }} + name: py${{ matrix.versions.python-version }} ${{ matrix.versions.resolution }} ${{ matrix.deps.name}} runs-on: ubuntu-latest strategy: matrix: @@ -22,6 +22,14 @@ jobs: resolution: highest - python-version: '3.12' resolution: highest + deps: + - name: minimal + value: '[dev]' + doctest: '' # doctest runs MCA and requires statsmodels + - name: complete + value: '[dev,complete]' + doctest: '--doctest-glob=README.md' + steps: - uses: actions/checkout@v4 @@ -33,11 +41,12 @@ jobs: - name: Install dependencies run: | pip install uv - uv pip install . -r pyproject.toml --system --extra dev --resolution ${{ matrix.versions.resolution }} + uv pip install .${{ matrix.deps.value }} -r pyproject.toml \ + --system --resolution ${{ matrix.versions.resolution }} - name: Execute Tests run: | - coverage run -m pytest -n auto --doctest-glob="README.md" + coverage run -m pytest -n auto ${{ matrix.deps.doctest }} coverage report -m coverage xml diff --git a/docs/content/contributing.rst b/docs/content/contributing.rst index 785c6e55..5ac5bff2 100644 --- a/docs/content/contributing.rst +++ b/docs/content/contributing.rst @@ -48,19 +48,9 @@ Using the commands below, prepare your environment: conda create -n xeofs python=3.11 rpy2 pandoc conda activate xeofs - pip install -e .[docs,dev] + pip install -e .[complete,docs,dev] -This will install all necessary dependencies, including those for development and documentation. If you're only updating the code (without modifying online documentation), you can skip the docs dependency: - -.. code-block:: bash - - pip install -e .[dev] - -On the other hand, if you're just updating documentation: - -.. code-block:: bash - - pip install -e .[docs] +This will install both core and optional dependencies, including those for specialized models, documentation, and development. Alternatively, you can skip some of the optional dependency sets (``[complete,docs,dev]``) depending on which components of the package you're working on. Additionally, install the pre-commit hooks: @@ -81,7 +71,7 @@ Before diving into your contribution, ensure your local main branch is updated: git fetch upstream git merge upstream/main -This syncs your local main branch with the latest from the primary `xeofs` repository. +This syncs your local main branch with the latest from the primary ``xeofs`` repository. 4. Create a new branch ---------------------- diff --git a/docs/content/user_guide/installation.rst b/docs/content/user_guide/installation.rst index 14252731..d6f29c27 100644 --- a/docs/content/user_guide/installation.rst +++ b/docs/content/user_guide/installation.rst @@ -1,35 +1,33 @@ Installation ------------ -Required Dependencies +Dependencies ~~~~~~~~~~~~~~~~~~~~~ -The following packages are required dependencies: +The following packages are dependencies of ``xeofs``: -**Core Dependencies** +**Core Dependencies (Required)** * Python (3.10 or higher) -* `numpy `__ -* `pandas `__ -* `xarray `__ -* `scikit-learn `__ -* `statsmodels `__ +* `numpy `__ +* `pandas `__ +* `xarray `__ +* `dask `__ +* `scikit-learn `__ +* `typing-extensions `__ +* `tqdm `__ -**For Performance** +**For Specialized Models (Optional)** -* `dask `__ -* `numba `__ +* `numba `__ +* `statsmodels `__ -**For I/O** +**For I/O (Optional)** -* `netCDF4 `__ -* `zarr `__ -* `xarray-datatree `__ +* `h5netcdf `__ +* `netCDF4 `__ +* `zarr `__ -**Miscellaneous** - -* `typing-extensions `__ -* `tqdm `__ Instructions ~~~~~~~~~~~~ @@ -46,3 +44,17 @@ or the Python package installer `pip =1.0.2", "tqdm>=4.64.0", "dask>=2023.0.1", - "statsmodels>=0.14.0", - "netCDF4>=1.5.8", - "numba>=0.57", "typing-extensions>=4.8.0", - "zarr>=2.14.0", "xarray-datatree>=0.0.12", ] [project.optional-dependencies] +complete = ["xeofs[etc,io]"] dev = [ "build>=1.0.0", "ruff>=0.3", @@ -53,6 +50,15 @@ docs = [ "ipython>=8.14", "ipykernel>=6.23", ] +etc = [ + "numba>=0.57", + "statsmodels>=0.14.0", +] +io = [ + "h5netcdf>=1.0.0", + "netcdf4>=1.5.8", + "zarr>=2.14.0", +] [project.urls] homepage = "https://github.com/xarray-contrib/xeofs" diff --git a/tests/models/cross/__init__.py b/tests/models/cross/__init__.py index e69de29b..e6ffd4c9 100644 --- a/tests/models/cross/__init__.py +++ b/tests/models/cross/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("statsmodels") diff --git a/tests/models/cross/test_cca.py b/tests/models/cross/test_cca.py index b603e9c3..b95daed6 100644 --- a/tests/models/cross/test_cca.py +++ b/tests/models/cross/test_cca.py @@ -5,6 +5,8 @@ from xeofs.cross import CCA +from ...utilities import skip_if_missing_engine + def generate_random_data(shape, lazy=False, seed=142): rng = np.random.default_rng(seed) @@ -226,11 +228,13 @@ def test_predict(): _ = cca.inverse_transform(Y=Ry_pred) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(tmp_path, engine): """Test save/load methods in MCA class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + X = generate_random_data((200, 10), seed=123) Y = generate_random_data((200, 20), seed=321) diff --git a/tests/models/cross/test_cpcca.py b/tests/models/cross/test_cpcca.py index fd81282a..2f6bce83 100644 --- a/tests/models/cross/test_cpcca.py +++ b/tests/models/cross/test_cpcca.py @@ -5,6 +5,8 @@ from xeofs.cross import CPCCA +from ...utilities import skip_if_missing_engine + def generate_random_data(shape, lazy=False, seed=142): rng = np.random.default_rng(seed) @@ -274,12 +276,14 @@ def test_predict(): _ = cpcca.inverse_transform(Y=Ry_pred) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) def test_save_load(tmp_path, engine, alpha): """Test save/load methods in MCA class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + X = generate_random_data((200, 10), seed=123) Y = generate_random_data((200, 20), seed=321) @@ -319,11 +323,13 @@ def test_save_load(tmp_path, engine, alpha): assert np.allclose(XYr_o[1], XYr_l[1]) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) def test_save_load_with_data(tmp_path, engine, alpha): """Test save/load methods in CPCCA class, ensuring that we can roundtrip the model and get the same results for SCF.""" + skip_if_missing_engine(engine) + X = generate_random_data((200, 10), seed=123) Y = generate_random_data((200, 20), seed=321) diff --git a/tests/models/cross/test_hilbert_cpcca.py b/tests/models/cross/test_hilbert_cpcca.py index d30b449d..79b267c5 100644 --- a/tests/models/cross/test_hilbert_cpcca.py +++ b/tests/models/cross/test_hilbert_cpcca.py @@ -5,6 +5,8 @@ from xeofs.cross import HilbertCPCCA +from ...utilities import skip_if_missing_engine + def generate_random_data(shape, lazy=False, seed=142): rng = np.random.default_rng(seed) @@ -65,11 +67,13 @@ def test_singular_values(use_pca): # Currently, netCDF4 does not support complex numbers, so skip this test -@pytest.mark.parametrize("engine", ["zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "zarr"]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) def test_save_load_with_data(tmp_path, engine, alpha): """Test save/load methods in CPCCA class, ensuring that we can roundtrip the model and get the same results.""" + skip_if_missing_engine(engine) + X = generate_random_data((200, 10), seed=123) Y = generate_random_data((200, 20), seed=321) diff --git a/tests/models/cross/test_hilbert_mca_rotator.py b/tests/models/cross/test_hilbert_mca_rotator.py index c1e47525..fd4fb04b 100644 --- a/tests/models/cross/test_hilbert_mca_rotator.py +++ b/tests/models/cross/test_hilbert_mca_rotator.py @@ -5,6 +5,8 @@ # Import the classes from your modules from xeofs.cross import HilbertMCA, HilbertMCARotator +from ...utilities import skip_if_missing_engine + @pytest.fixture def mca_model(mock_data_array, dim): @@ -242,10 +244,12 @@ def test_scores_phase(mca_model, mock_data_array, dim): ], ) # Currently, netCDF4 does not support complex numbers, so skip this test -@pytest.mark.parametrize("engine", ["zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "zarr"]) def test_save_load_with_data(tmp_path, engine, mca_model): """Test save/load methods in HilbertMCARotator class, ensuring that we can roundtrip the model and get the same results.""" + skip_if_missing_engine(engine) + original = HilbertMCARotator(n_modes=2) original.fit(mca_model) diff --git a/tests/models/cross/test_mca.py b/tests/models/cross/test_mca.py index 3d063685..3fdc1b6e 100644 --- a/tests/models/cross/test_mca.py +++ b/tests/models/cross/test_mca.py @@ -4,7 +4,7 @@ from xeofs.cross import MCA -from ...utilities import data_is_dask +from ...utilities import data_is_dask, skip_if_missing_engine @pytest.fixture @@ -376,11 +376,13 @@ def test_compute(mock_dask_data_array, dim, compute): (("lon", "lat")), ], ) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(dim, mock_data_array, tmp_path, engine): """Test save/load methods in MCA class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + original = MCA() original.fit(mock_data_array, mock_data_array, dim) diff --git a/tests/models/cross/test_mca_rotator.py b/tests/models/cross/test_mca_rotator.py index 9c866082..1ae7b891 100644 --- a/tests/models/cross/test_mca_rotator.py +++ b/tests/models/cross/test_mca_rotator.py @@ -5,7 +5,7 @@ # Import the classes from your modules from xeofs.cross import MCA, MCARotator -from ...utilities import data_is_dask +from ...utilities import data_is_dask, skip_if_missing_engine @pytest.fixture @@ -230,11 +230,13 @@ def test_compute(mca_model_delayed, compute): (("lon", "lat")), ], ) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(dim, mock_data_array, tmp_path, engine): """Test save/load methods in MCA class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + original_unrotated = MCA() original_unrotated.fit(mock_data_array, mock_data_array, dim) diff --git a/tests/models/cross/test_rda.py b/tests/models/cross/test_rda.py index d582d30b..9fdfcfd1 100644 --- a/tests/models/cross/test_rda.py +++ b/tests/models/cross/test_rda.py @@ -5,6 +5,8 @@ from xeofs.cross import RDA +from ...utilities import skip_if_missing_engine + def generate_random_data(shape, lazy=False, seed=142): rng = np.random.default_rng(seed) @@ -226,11 +228,13 @@ def test_predict(): _ = rda.inverse_transform(Y=Ry_pred) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(tmp_path, engine): """Test save/load methods in MCA class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + X = generate_random_data((200, 10), seed=123) Y = generate_random_data((200, 20), seed=321) diff --git a/tests/models/single/test_eof.py b/tests/models/single/test_eof.py index c7373152..9294f0db 100644 --- a/tests/models/single/test_eof.py +++ b/tests/models/single/test_eof.py @@ -4,6 +4,8 @@ from xeofs.single import EOF +from ...utilities import skip_if_missing_engine + def test_init(): """Tests the initialization of the EOF class""" @@ -494,11 +496,13 @@ def test_inverse_transform(dim, mock_data_array, normalized): (("lon", "lat")), ], ) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(dim, mock_data_array, tmp_path, engine): """Test save/load methods in EOF class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + original = EOF() original.fit(mock_data_array, dim) diff --git a/tests/models/single/test_eof_rotator.py b/tests/models/single/test_eof_rotator.py index 22c426eb..e47ffa52 100644 --- a/tests/models/single/test_eof_rotator.py +++ b/tests/models/single/test_eof_rotator.py @@ -5,7 +5,7 @@ from xeofs.data_container import DataContainer from xeofs.single import EOF, EOFRotator -from ...utilities import data_is_dask +from ...utilities import data_is_dask, skip_if_missing_engine @pytest.fixture @@ -203,11 +203,13 @@ def test_compute(eof_model_delayed, compute): (("lon", "lat")), ], ) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(dim, mock_data_array, tmp_path, engine): """Test save/load methods in EOF class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + original_unrotated = EOF() original_unrotated.fit(mock_data_array, dim) diff --git a/tests/models/single/test_gwpca.py b/tests/models/single/test_gwpca.py index c7431a67..81c75b2b 100644 --- a/tests/models/single/test_gwpca.py +++ b/tests/models/single/test_gwpca.py @@ -2,6 +2,8 @@ import xeofs as xe +pytest.importorskip("numba") + # ============================================================================= # GENERALLY VALID TEST CASES # ============================================================================= diff --git a/tests/models/single/test_pop.py b/tests/models/single/test_pop.py index fe9f9b4c..1b155d7e 100644 --- a/tests/models/single/test_pop.py +++ b/tests/models/single/test_pop.py @@ -4,6 +4,8 @@ from xeofs.single import POP +from ...utilities import skip_if_missing_engine + def test_init(): """Tests the initialization of the POP class""" @@ -153,12 +155,14 @@ def test_inverse_transform(mock_data_array): assert set(X_rec.dims) == set(mock_data_array.dims) -@pytest.mark.parametrize("engine", ["zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "zarr"]) def test_save_load(mock_data_array, tmp_path, engine): """Test save/load methods in POP class, ensuring that we can roundtrip the model and get the same results when transforming data.""" # NOTE: netcdf4 does not support complex data types, so we use only zarr here + skip_if_missing_engine(engine) + dim = "time" original = POP() original.fit(mock_data_array, dim) diff --git a/tests/models/single/test_sparse_pca.py b/tests/models/single/test_sparse_pca.py index f4a1bc56..d6b65149 100644 --- a/tests/models/single/test_sparse_pca.py +++ b/tests/models/single/test_sparse_pca.py @@ -4,6 +4,8 @@ from xeofs.single import SparsePCA +from ...utilities import skip_if_missing_engine + def test_init(): """Tests the initialization of the SparsePCA class""" @@ -483,11 +485,13 @@ def test_inverse_transform(dim, mock_data_array, normalized): (("lon", "lat")), ], ) -@pytest.mark.parametrize("engine", ["netcdf4", "zarr"]) +@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4", "zarr"]) def test_save_load(dim, mock_data_array, tmp_path, engine): """Test save/load methods in SparsePCA class, ensuring that we can roundtrip the model and get the same results when transforming data.""" + skip_if_missing_engine(engine) + original = SparsePCA() original.fit(mock_data_array, dim) diff --git a/tests/models/test_rotator_factory.py b/tests/models/test_rotator_factory.py index 78ade121..45ec2014 100644 --- a/tests/models/test_rotator_factory.py +++ b/tests/models/test_rotator_factory.py @@ -29,6 +29,7 @@ def test_create_rotator_HilbertEOF(): def test_create_rotator_MCA(): + pytest.importorskip("statsmodels") factory = RotatorFactory(n_modes=3, power=2, max_iter=1000, rtol=1e-8) MCA_instance = MCA() rotator = factory.create_rotator(MCA_instance) @@ -36,6 +37,7 @@ def test_create_rotator_MCA(): def test_create_rotator_HilbertMCA(): + pytest.importorskip("statsmodels") factory = RotatorFactory(n_modes=3, power=2, max_iter=1000, rtol=1e-8) HilbertMCA_instance = HilbertMCA() rotator = factory.create_rotator(HilbertMCA_instance) diff --git a/tests/utilities.py b/tests/utilities.py index 1243cc56..b9826eb4 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import pandas as pd from xeofs.utils.data_types import ( @@ -148,3 +149,15 @@ def assert_expected_coords(data1, data2, policy="all") -> None: type(data1), type(data2) ) ) + + +def skip_if_missing_engine(engine: str): + """ + Skip save/load tests if missing the i/o backend. + """ + # xarray uses engine="netcdf4" but the package itself is "netCDF4". + mapping = {"h5netcdf": "h5netcdf", "netcdf4": "netCDF4", "zarr": "zarr"} + module = mapping.get(engine) + if module is None: + raise ValueError(f"Unrecognized engine: {engine}") + pytest.importorskip(module) diff --git a/xeofs/base_model.py b/xeofs/base_model.py index 16aae465..dffb9c5f 100644 --- a/xeofs/base_model.py +++ b/xeofs/base_model.py @@ -1,3 +1,4 @@ +import importlib import warnings from abc import ABC, abstractmethod from datetime import datetime @@ -31,6 +32,9 @@ class BaseModel(ABC): """ + extra_modules = [] + uses_complex = False + def __init__(self): # Define model parameters self._params = {} @@ -46,6 +50,10 @@ def __init__(self): ) self.attrs.update(self._params) + # Ensure necessary non-core dependencies are available + for module in self.extra_modules: + self.check_needed_module(module) + @abstractmethod def get_serialization_attrs(self) -> dict: """Get the attributes to serialize.""" @@ -135,6 +143,9 @@ def save( if not save_data: dt = insert_placeholders(dt) + if self.uses_complex and engine == "h5netcdf": + kwargs = {"invalid_netcdf": True} | kwargs + write_model_tree(dt, path, overwrite=overwrite, engine=engine, **kwargs) @classmethod @@ -187,3 +198,12 @@ def load( def _validate_loaded_data(self, X: DataArray): """Optionally check the loaded data for placeholders.""" pass + + def check_needed_module(self, module: str): + """Check if a necessary non-core dependency is available.""" + try: + importlib.import_module(module) + except ImportError: + raise ImportError( + f"Additional module {module} is required for {self.__class__.__name__}." + ) diff --git a/xeofs/cross/cpcca.py b/xeofs/cross/cpcca.py index d6cf60fe..bd6446d3 100644 --- a/xeofs/cross/cpcca.py +++ b/xeofs/cross/cpcca.py @@ -9,7 +9,6 @@ from ..linalg.decomposer import Decomposer from ..utils.data_types import DataArray, DataObject from ..utils.hilbert_transform import hilbert_transform -from ..utils.statistics import pearson_correlation from ..utils.xarray_utils import argsort_dask from .base_model_cross_set import BaseModelCrossSet @@ -121,6 +120,8 @@ class CPCCA(BaseModelCrossSet): """ + extra_modules = ["statsmodels"] + def __init__( self, n_modes: int = 2, @@ -764,6 +765,8 @@ def homogeneous_patterns(self, correction=None, alpha=0.05): p-values of the homogenous correlation patterns of `X` and `Y`. """ + from ..utils.optional.statistics import pearson_correlation + input_data1 = self.data["input_data1"] input_data2 = self.data["input_data2"] @@ -848,6 +851,8 @@ def heterogeneous_patterns(self, correction=None, alpha=0.05): p-values of the heterogenous correlation patterns of `X` and `Y`. """ + from ..utils.optional.statistics import pearson_correlation + input_data1 = self.data["input_data1"] input_data2 = self.data["input_data2"] @@ -1121,6 +1126,8 @@ class ComplexCPCCA(CPCCA): """ + uses_complex = True + def __init__( self, n_modes: int = 2, diff --git a/xeofs/single/eof.py b/xeofs/single/eof.py index 0d94be62..3075a01d 100644 --- a/xeofs/single/eof.py +++ b/xeofs/single/eof.py @@ -300,6 +300,8 @@ class ComplexEOF(EOF): """ + uses_complex = True + def __init__( self, n_modes: int = 2, diff --git a/xeofs/single/gwpca.py b/xeofs/single/gwpca.py index df087a3d..bcc4d4a0 100644 --- a/xeofs/single/gwpca.py +++ b/xeofs/single/gwpca.py @@ -9,9 +9,9 @@ VALID_CARTESIAN_Y_NAMES, VALID_LATITUDE_NAMES, VALID_LONGITUDE_NAMES, + VALID_KERNELS, + VALID_METRICS, ) -from ..utils.distance_metrics import VALID_METRICS -from ..utils.kernels import VALID_KERNELS from ..utils.sanity_checks import assert_not_complex from .base_model_single_set import BaseModelSingleSet @@ -84,6 +84,8 @@ class GWPCA(BaseModelSingleSet): """ + extra_modules = ["numba"] + def __init__( self, n_modes: int, @@ -128,7 +130,7 @@ def __init__( def _fit_algorithm(self, X: DataArray) -> Self: # Hide numba imports here to greatly speed up module import time - from ..utils.numba_utils import _local_pcas + from ..utils.optional.numba_utils import _local_pcas # Check input type assert_not_complex(X) diff --git a/xeofs/single/pop.py b/xeofs/single/pop.py index 6173c72c..057464b2 100644 --- a/xeofs/single/pop.py +++ b/xeofs/single/pop.py @@ -99,6 +99,8 @@ class POP(BaseModelSingleSet): """ + uses_complex = True + def __init__( self, n_modes: int = 2, diff --git a/xeofs/utils/constants.py b/xeofs/utils/constants.py index aab85e48..9b34b488 100644 --- a/xeofs/utils/constants.py +++ b/xeofs/utils/constants.py @@ -14,6 +14,10 @@ VALID_CARTESIAN_X_NAMES = ["x", "x_coord"] VALID_CARTESIAN_Y_NAMES = ["y", "y_coord"] +VALID_KERNELS = ["bisquare", "gaussian", "exponential"] +VALID_METRICS = ["euclidean", "haversine"] + + MULTIPLE_TESTS = [ "bonferroni", "sidak", diff --git a/xeofs/utils/optional/__init__.py b/xeofs/utils/optional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/xeofs/utils/distance_metrics.py b/xeofs/utils/optional/distance_metrics.py similarity index 97% rename from xeofs/utils/distance_metrics.py rename to xeofs/utils/optional/distance_metrics.py index 1b800097..7c7e37aa 100644 --- a/xeofs/utils/distance_metrics.py +++ b/xeofs/utils/optional/distance_metrics.py @@ -3,9 +3,7 @@ from numba import prange from scipy.spatial.distance import cdist -from .constants import AVG_EARTH_RADIUS - -VALID_METRICS = ["euclidean", "haversine"] +from ..constants import AVG_EARTH_RADIUS def distance_matrix_bc(A, B, metric="haversine"): diff --git a/xeofs/utils/kernels.py b/xeofs/utils/optional/kernels.py similarity index 94% rename from xeofs/utils/kernels.py rename to xeofs/utils/optional/kernels.py index a01b5a30..56db0ac4 100644 --- a/xeofs/utils/kernels.py +++ b/xeofs/utils/optional/kernels.py @@ -1,8 +1,6 @@ import numpy as np import numba -VALID_KERNELS = ["bisquare", "gaussian", "exponential"] - @numba.njit(fastmath=True) def kernel_weights_nb(distance, bandwidth, kernel): diff --git a/xeofs/utils/numba_utils.py b/xeofs/utils/optional/numba_utils.py similarity index 98% rename from xeofs/utils/numba_utils.py rename to xeofs/utils/optional/numba_utils.py index 4902545a..ea9e1465 100644 --- a/xeofs/utils/numba_utils.py +++ b/xeofs/utils/optional/numba_utils.py @@ -2,8 +2,8 @@ import numba from numba import prange -from ..utils.distance_metrics import distance_nb -from ..utils.kernels import kernel_weights_nb +from ..optional.distance_metrics import distance_nb +from ..optional.kernels import kernel_weights_nb # Additional utility functions for local PCA diff --git a/xeofs/utils/statistics.py b/xeofs/utils/optional/statistics.py similarity index 99% rename from xeofs/utils/statistics.py rename to xeofs/utils/optional/statistics.py index 3df80ddf..fbacd540 100644 --- a/xeofs/utils/statistics.py +++ b/xeofs/utils/optional/statistics.py @@ -2,7 +2,7 @@ import xarray as xr from statsmodels.stats.multitest import multipletests as statsmodels_multipletests -from .constants import MULTIPLE_TESTS +from ..constants import MULTIPLE_TESTS def pearson_correlation(