diff --git a/tests/conftest.py b/tests/conftest.py index 681274b..9911468 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import jax.numpy @@ -8,9 +9,18 @@ import array_api_strict +import glass._array_api_utils + if TYPE_CHECKING: from types import ModuleType + from glass._types import UnifiedGenerator + + +# Change jax logger to only log ERROR or worse +logging.getLogger("jax").setLevel(logging.ERROR) + + xp_available_backends: dict[str, ModuleType] = { "array_api_strict": array_api_strict, "numpy": numpy, @@ -26,3 +36,15 @@ def xp(request: pytest.FixtureRequest) -> ModuleType: Access array library functions using `xp.` in tests. """ return request.param # type: ignore[no-any-return] + + +@pytest.fixture(scope="session") +def urng(xp: ModuleType) -> UnifiedGenerator: + """ + Fixture for a unified RNG interface. + + Access the relevant RNG using `urng.` in tests. + + Must be used with the `xp` fixture. Use `rng` for non array API tests. + """ + return glass._array_api_utils.rng_dispatcher(xp=xp) # noqa: SLF001 diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 52b5762..6ad8225 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +import numpy as np import pytest import glass.algorithm @@ -13,6 +14,8 @@ from pytest_benchmark.fixture import BenchmarkFixture + from glass._types import UnifiedGenerator + def test_nnls(xp: ModuleType, benchmark: BenchmarkFixture) -> None: """ @@ -31,3 +34,82 @@ def test_nnls(xp: ModuleType, benchmark: BenchmarkFixture) -> None: y = a @ b res = benchmark(glass.algorithm.nnls, a, y) assert xp.linalg.vector_norm((a @ res) - y) < 1e-7 + + +@pytest.mark.parametrize("rtol", [None, 1.0]) +def test_cov_clip( + xp: ModuleType, + urng: UnifiedGenerator, + benchmark: BenchmarkFixture, + rtol: float | None, +) -> None: + """ + Benchmark test for glass.algorithm.cov_clip. + + Parameterize over rtol to ensure the most coverage possible. + """ + # prepare a random matrix + m = urng.random((4, 4)) + + # symmetric matrix + a = (m + m.T) / 2 + + # fix by clipping negative eigenvalues + cov = benchmark(glass.algorithm.cov_clip, a, rtol=rtol) + + # make sure all eigenvalues are positive + assert xp.all(xp.linalg.eigvalsh(cov) >= 0) + + if rtol is not None: + h = xp.max(xp.linalg.eigvalsh(a)) + np.testing.assert_allclose(xp.linalg.eigvalsh(cov), h, rtol=1e-6) + + +@pytest.mark.parametrize("tol", [None, 0.0001]) +def test_nearcorr( + xp: ModuleType, + benchmark: BenchmarkFixture, + tol: float | None, +) -> None: + """ + Benchmark test for glass.algorithm.nearcorr. + + Parameterize over tol to ensure the most coverage possible. + """ + # from Higham (2002) + a = xp.asarray( + [ + [1.0, 1.0, 0.0], + [1.0, 1.0, 1.0], + [0.0, 1.0, 1.0], + ], + ) + b = xp.asarray( + [ + [1.0000, 0.7607, 0.1573], + [0.7607, 1.0000, 0.7607], + [0.1573, 0.7607, 1.0000], + ], + ) + + x = benchmark(glass.algorithm.nearcorr, a, tol=tol) + np.testing.assert_allclose(x, b, atol=0.0001) + + +def test_cov_nearest( + xp: ModuleType, + urng: UnifiedGenerator, + benchmark: BenchmarkFixture, +) -> None: + """Benchmark test for glass.algorithm.cov_nearest.""" + # prepare a random matrix + m = urng.random((4, 4)) + + # symmetric matrix + a = xp.eye(4) + (m + m.T) / 2 + + # compute covariance + cov = benchmark(glass.algorithm.cov_nearest, a) + + # make sure all eigenvalues are positive + assert xp.all(xp.linalg.eigvalsh(cov) >= 0)