Skip to content
This repository was archived by the owner on Nov 13, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import jax.numpy
Expand All @@ -8,9 +9,18 @@

import array_api_strict

import glass._array_api_utils

if TYPE_CHECKING:
from types import ModuleType

from glass._types import UnifiedGenerator


# Change jax logger to only log ERROR or worse
logging.getLogger("jax").setLevel(logging.ERROR)


xp_available_backends: dict[str, ModuleType] = {
"array_api_strict": array_api_strict,
"numpy": numpy,
Expand All @@ -26,3 +36,15 @@ def xp(request: pytest.FixtureRequest) -> ModuleType:
Access array library functions using `xp.` in tests.
"""
return request.param # type: ignore[no-any-return]


@pytest.fixture(scope="session")
def urng(xp: ModuleType) -> UnifiedGenerator:
"""
Fixture for a unified RNG interface.

Access the relevant RNG using `urng.` in tests.

Must be used with the `xp` fixture. Use `rng` for non array API tests.
"""
return glass._array_api_utils.rng_dispatcher(xp=xp) # noqa: SLF001
82 changes: 82 additions & 0 deletions tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import TYPE_CHECKING

import numpy as np
import pytest

import glass.algorithm
Expand All @@ -13,6 +14,8 @@

from pytest_benchmark.fixture import BenchmarkFixture

from glass._types import UnifiedGenerator


def test_nnls(xp: ModuleType, benchmark: BenchmarkFixture) -> None:
"""
Expand All @@ -31,3 +34,82 @@ def test_nnls(xp: ModuleType, benchmark: BenchmarkFixture) -> None:
y = a @ b
res = benchmark(glass.algorithm.nnls, a, y)
assert xp.linalg.vector_norm((a @ res) - y) < 1e-7


@pytest.mark.parametrize("rtol", [None, 1.0])
def test_cov_clip(
xp: ModuleType,
urng: UnifiedGenerator,
benchmark: BenchmarkFixture,
rtol: float | None,
) -> None:
"""
Benchmark test for glass.algorithm.cov_clip.

Parameterize over rtol to ensure the most coverage possible.
"""
# prepare a random matrix
m = urng.random((4, 4))

# symmetric matrix
a = (m + m.T) / 2

# fix by clipping negative eigenvalues
cov = benchmark(glass.algorithm.cov_clip, a, rtol=rtol)

# make sure all eigenvalues are positive
assert xp.all(xp.linalg.eigvalsh(cov) >= 0)

if rtol is not None:
h = xp.max(xp.linalg.eigvalsh(a))
np.testing.assert_allclose(xp.linalg.eigvalsh(cov), h, rtol=1e-6)


@pytest.mark.parametrize("tol", [None, 0.0001])
def test_nearcorr(
xp: ModuleType,
benchmark: BenchmarkFixture,
tol: float | None,
) -> None:
"""
Benchmark test for glass.algorithm.nearcorr.

Parameterize over tol to ensure the most coverage possible.
"""
# from Higham (2002)
a = xp.asarray(
[
[1.0, 1.0, 0.0],
[1.0, 1.0, 1.0],
[0.0, 1.0, 1.0],
],
)
b = xp.asarray(
[
[1.0000, 0.7607, 0.1573],
[0.7607, 1.0000, 0.7607],
[0.1573, 0.7607, 1.0000],
],
)

x = benchmark(glass.algorithm.nearcorr, a, tol=tol)
np.testing.assert_allclose(x, b, atol=0.0001)


def test_cov_nearest(
xp: ModuleType,
urng: UnifiedGenerator,
benchmark: BenchmarkFixture,
) -> None:
"""Benchmark test for glass.algorithm.cov_nearest."""
# prepare a random matrix
m = urng.random((4, 4))

# symmetric matrix
a = xp.eye(4) + (m + m.T) / 2

# compute covariance
cov = benchmark(glass.algorithm.cov_nearest, a)

# make sure all eigenvalues are positive
assert xp.all(xp.linalg.eigvalsh(cov) >= 0)