From 73740defc3c0edb71a72a90c938de355da25708f Mon Sep 17 00:00:00 2001 From: connoraird Date: Thu, 6 Nov 2025 14:38:03 +0000 Subject: [PATCH 1/5] Add test_cov_clip benchmark --- tests/conftest.py | 38 ++++++++++++++++++++++++++++++++++++-- tests/test_algorithm.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 681274b..0da8030 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,32 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import jax.numpy -import numpy # noqa: ICN001 +import numpy as np import pytest import array_api_strict +import glass.jax + if TYPE_CHECKING: from types import ModuleType + from typing import TypeAlias + + UnifiedGenerator: TypeAlias = ( + np.random.Generator | glass.jax.Generator | glass._array_api_utils.Generator # noqa: SLF001 + ) + + +# Change jax logger to only log ERROR or worse +logging.getLogger("jax").setLevel(logging.ERROR) + xp_available_backends: dict[str, ModuleType] = { "array_api_strict": array_api_strict, - "numpy": numpy, + "numpy": np, "jax.numpy": jax.numpy, } @@ -26,3 +39,24 @@ def xp(request: pytest.FixtureRequest) -> ModuleType: Access array library functions using `xp.` in tests. """ return request.param # type: ignore[no-any-return] + + +@pytest.fixture(scope="session") +def urng(xp: ModuleType) -> UnifiedGenerator: + """ + Fixture for a unified RNG interface. + + Access the relevant RNG using `urng.` in tests. + + Must be used with the `xp` fixture. Use `rng` for non array API tests. + """ + seed = 42 + backend = xp.__name__ + if backend == "jax.numpy": + return glass.jax.Generator(seed=seed) + if backend == "numpy": + return np.random.default_rng(seed=seed) + if backend == "array_api_strict": + return glass._array_api_utils.Generator(seed=seed) # noqa: SLF001 + msg = "the array backend in not supported" + raise NotImplementedError(msg) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 52b5762..81946c1 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +import numpy as np import pytest import glass.algorithm @@ -13,6 +14,8 @@ from pytest_benchmark.fixture import BenchmarkFixture + from glass._types import UnifiedGenerator + def test_nnls(xp: ModuleType, benchmark: BenchmarkFixture) -> None: """ @@ -31,3 +34,32 @@ def test_nnls(xp: ModuleType, benchmark: BenchmarkFixture) -> None: y = a @ b res = benchmark(glass.algorithm.nnls, a, y) assert xp.linalg.vector_norm((a @ res) - y) < 1e-7 + + +@pytest.mark.parametrize("rtol", [None, 1.0]) +def test_cov_clip( + xp: ModuleType, + urng: UnifiedGenerator, + benchmark: BenchmarkFixture, + rtol: float | None, +) -> None: + """ + Benchmark test for glass.algorithm.cov_clip. + + Parameterize over rtol to ensure the most coverage possible. + """ + # prepare a random matrix + m = urng.random((4, 4)) + + # symmetric matrix + a = (m + m.T) / 2 + + # fix by clipping negative eigenvalues + cov = benchmark(glass.algorithm.cov_clip, a, rtol=rtol) + + # make sure all eigenvalues are positive + assert xp.all(xp.linalg.eigvalsh(cov) >= 0) + + if rtol is not None: + h = xp.max(xp.linalg.eigvalsh(a)) + np.testing.assert_allclose(xp.linalg.eigvalsh(cov), h, rtol=1e-6) From 18ed713a0c85029e277a8c2924707caeb9b81acf Mon Sep 17 00:00:00 2001 From: connoraird Date: Thu, 6 Nov 2025 14:48:34 +0000 Subject: [PATCH 2/5] Add test_nearcorr benchmark --- tests/test_algorithm.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 81946c1..aad6c96 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -63,3 +63,34 @@ def test_cov_clip( if rtol is not None: h = xp.max(xp.linalg.eigvalsh(a)) np.testing.assert_allclose(xp.linalg.eigvalsh(cov), h, rtol=1e-6) + + +@pytest.mark.parametrize("tol", [None, 0.0001]) +def test_nearcorr( + xp: ModuleType, + benchmark: BenchmarkFixture, + tol: float | None, +) -> None: + """ + Benchmark test for glass.algorithm.nearcorr. + + Parameterize over tol to ensure the most coverage possible. + """ + # from Higham (2002) + a = xp.asarray( + [ + [1.0, 1.0, 0.0], + [1.0, 1.0, 1.0], + [0.0, 1.0, 1.0], + ], + ) + b = xp.asarray( + [ + [1.0000, 0.7607, 0.1573], + [0.7607, 1.0000, 0.7607], + [0.1573, 0.7607, 1.0000], + ], + ) + + x = benchmark(glass.algorithm.nearcorr, a, tol=tol) + np.testing.assert_allclose(x, b, atol=0.0001) From de2d42fcb12656fa5a461825ea11b1f680987f05 Mon Sep 17 00:00:00 2001 From: connoraird Date: Thu, 6 Nov 2025 14:52:23 +0000 Subject: [PATCH 3/5] Add test_cov_nearest benchmark --- tests/test_algorithm.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index aad6c96..6ad8225 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -94,3 +94,22 @@ def test_nearcorr( x = benchmark(glass.algorithm.nearcorr, a, tol=tol) np.testing.assert_allclose(x, b, atol=0.0001) + + +def test_cov_nearest( + xp: ModuleType, + urng: UnifiedGenerator, + benchmark: BenchmarkFixture, +) -> None: + """Benchmark test for glass.algorithm.cov_nearest.""" + # prepare a random matrix + m = urng.random((4, 4)) + + # symmetric matrix + a = xp.eye(4) + (m + m.T) / 2 + + # compute covariance + cov = benchmark(glass.algorithm.cov_nearest, a) + + # make sure all eigenvalues are positive + assert xp.all(xp.linalg.eigvalsh(cov) >= 0) From 1208fb1cee6116536d57b44d8c3240f4d12969da Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 12 Nov 2025 10:17:09 +0000 Subject: [PATCH 4/5] Use rng already defined in glass --- tests/conftest.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0da8030..9d8fed8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,15 +9,12 @@ import array_api_strict -import glass.jax +import glass._array_api_utils if TYPE_CHECKING: from types import ModuleType - from typing import TypeAlias - UnifiedGenerator: TypeAlias = ( - np.random.Generator | glass.jax.Generator | glass._array_api_utils.Generator # noqa: SLF001 - ) + from glass._types import UnifiedGenerator # Change jax logger to only log ERROR or worse @@ -50,13 +47,4 @@ def urng(xp: ModuleType) -> UnifiedGenerator: Must be used with the `xp` fixture. Use `rng` for non array API tests. """ - seed = 42 - backend = xp.__name__ - if backend == "jax.numpy": - return glass.jax.Generator(seed=seed) - if backend == "numpy": - return np.random.default_rng(seed=seed) - if backend == "array_api_strict": - return glass._array_api_utils.Generator(seed=seed) # noqa: SLF001 - msg = "the array backend in not supported" - raise NotImplementedError(msg) + return glass._array_api_utils.rng_dispatcher(xp=xp) # noqa: SLF001 From 85607a2aa0ac0ce7412a4a93a5639be3cb973b0f Mon Sep 17 00:00:00 2001 From: connoraird Date: Wed, 12 Nov 2025 10:36:03 +0000 Subject: [PATCH 5/5] Switch np to numpy --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9d8fed8..9911468 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING import jax.numpy -import numpy as np +import numpy # noqa: ICN001 import pytest import array_api_strict @@ -23,7 +23,7 @@ xp_available_backends: dict[str, ModuleType] = { "array_api_strict": array_api_strict, - "numpy": np, + "numpy": numpy, "jax.numpy": jax.numpy, }