Changes from all commits (29 commits):
- 8a5fe45  this branch isn't fully functional yet, but let's get this in place f… (quattro, Jun 26, 2024)
- 35be0f4  this branch isn't fully functional yet, but let's get this in place f… (quattro, Jun 26, 2024)
- d71b5a4  this branch isn't fully functional yet, but let's get this in place f… (quattro, Jun 26, 2024)
- 3b58cd1  checkpoint for handoff (quattro, Sep 20, 2024)
- 06a57de  test: tridiagonal linear operator (nahid18, Sep 23, 2024)
- cba2d0d  chore: remove import (nahid18, Sep 23, 2024)
- 288328b  Merge pull request #7 from nahid18/nonlinear (quattro, Sep 23, 2024)
- a691c30  test: identity linear operator (nahid18, Sep 30, 2024)
- d964667  test: tagged linear operator (nahid18, Sep 30, 2024)
- 66a36d3  chore: added trailing space (nahid18, Sep 30, 2024)
- feec54a  Merge pull request #8 from nahid18/nonlinear (quattro, Oct 1, 2024)
- 5775168  test: compound operators (nahid18, Oct 2, 2024)
- 2b67f58  test: tagged operator (nahid18, Oct 2, 2024)
- 1214a58  Merge pull request #9 from nahid18/nonlinear (quattro, Oct 2, 2024)
- f71250e  refactor build diagonal to a dispatcher (nahid18, Oct 4, 2024)
- 4ac0977  feat: add, neg, div, composed (nahid18, Oct 5, 2024)
- f232896  Merge pull request #10 from nahid18/nonlinear (quattro, Oct 7, 2024)
- 445a135  switch back to eqxi for primitive construction. simplified few ident ops (quattro, Oct 8, 2024)
- a391d9b  comment out filter jit (nahid18, Oct 9, 2024)
- 8a8fdac  fixed dtype issues where compiled and run dtypes mismatch (quattro, Oct 9, 2024)
- fe34584  Merge branch 'mancusolab:nonlinear' into nonlinear (nahid18, Oct 10, 2024)
- e65a14a  tangent linear operator and materialise zeros (nahid18, Oct 11, 2024)
- 803a2d9  chore: remove import (nahid18, Oct 11, 2024)
- 324bc09  Merge pull request #11 from nahid18/nonlinear (quattro, Oct 11, 2024)
- dc5890b  chore: organize (nahid18, Oct 14, 2024)
- 7f32881  chore: rearrange code (nahid18, Nov 14, 2024)
- ee1b3d7  fix: scale bug in XTrace (nahid18, Mar 13, 2025)
- b926974  Merge pull request #12 from nahid18/nonlinear (quattro, Mar 13, 2025)
- 9d9a074  Merge branch 'main' into nonlinear (nahid18, Jul 23, 2025)
8 changes: 4 additions & 4 deletions README.md
@@ -70,20 +70,20 @@ key, key1, key2, key3, key4 = rdm.split(key, 5)
 
 # Hutchinson estimator; default samples Rademacher {-1,+1}
 hutch = tx.HutchinsonEstimator()
-print(hutch.estimate(key1, operator, k))  # (Array(3.4099615, dtype=float32), {})
+print(tx.trace(key1, operator, k, hutch))  # (Array(3.6007538, dtype=float32), {})
 
 # Hutch++ estimator; default samples Rademacher {-1,+1}
 hpp = tx.HutchPlusPlusEstimator()
-print(hpp.estimate(key2, operator, k))  # (Array(3.3033807, dtype=float32), {})
+print(tx.trace(key2, operator, k, hpp))  # (Array(3.4094956, dtype=float32), {})
 
 # XTrace estimator; default samples uniformly on n-Sphere
 xt = tx.XTraceEstimator()
-print(xt.estimate(key3, operator, k))  # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)})
+print(tx.trace(key3, operator, k, xt))  # (Array(3.3030486, dtype=float32), {'std.err': Array(0.01238528, dtype=float32)})
 
 # XNysTrace estimator; Improved performance for NSD/PSD trace estimates
 operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
 nt = tx.XNysTraceEstimator()
-print(nt.estimate(key4, operator, k))  # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)})
+print(tx.trace(key4, operator, k, nt))  # (Array(3.3314352, dtype=float32), {'std.err': Array(0.0006521, dtype=float32)})
 ```
 
 ## Documentation
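As a quick sanity check on the numbers above, the Girard-Hutchinson identity that all of these estimators build on, $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$, can be verified in a few lines of plain JAX, independent of this package. The sketch below is illustrative only and is not part of this PR; the matrix size, seeds, and number of matvecs are arbitrary choices:

```python
# Verify E[w^T A w] = trace(A) with Rademacher test vectors w.
import jax.numpy as jnp
import jax.random as rdm

key_a, key_w = rdm.split(rdm.PRNGKey(0))
n, k = 100, 10_000

A = rdm.normal(key_a, (n, n))
A = 0.5 * (A + A.T)  # symmetrize; the trace is unchanged

# (n, k) matrix of i.i.d. {-1, +1} test vectors
omega = rdm.rademacher(key_w, (n, k)).astype(A.dtype)

# mean of w^T A w over the k test vectors
est = jnp.sum((A @ omega) * omega) / k
print(est, jnp.trace(A))  # agree up to O(1/sqrt(k)) Monte Carlo noise
```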
13 changes: 8 additions & 5 deletions src/traceax/__init__.py
@@ -15,18 +15,21 @@
 from importlib.metadata import PackageNotFoundError, version  # pragma: no cover
 
 from ._estimators import (
-    AbstractTraceEstimator as AbstractTraceEstimator,
-    HutchinsonEstimator as HutchinsonEstimator,
-    HutchPlusPlusEstimator as HutchPlusPlusEstimator,
-    XNysTraceEstimator as XNysTraceEstimator,
-    XTraceEstimator as XTraceEstimator,
+    AbstractEstimator as AbstractEstimator,
 )
 from ._samplers import (
     AbstractSampler as AbstractSampler,
     NormalSampler as NormalSampler,
     RademacherSampler as RademacherSampler,
     SphereSampler as SphereSampler,
 )
+from ._trace import (
+    HutchinsonEstimator as HutchinsonEstimator,
+    HutchPlusPlusEstimator as HutchPlusPlusEstimator,
+    trace as trace,
+    XNysTraceEstimator as XNysTraceEstimator,
+    XTraceEstimator as XTraceEstimator,
+)
 
 
 try:
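The net effect on the public namespace: `AbstractTraceEstimator` is gone, `AbstractEstimator` and the new `trace` entry point are exported, and the concrete estimators now come from `_trace` rather than `_estimators`. Assuming the branch is installed as `traceax`, the root-level import surface declared above looks like:

```python
from traceax import (
    AbstractEstimator,      # replaces AbstractTraceEstimator, from ._estimators
    AbstractSampler,
    HutchinsonEstimator,    # concrete estimators now live in ._trace
    HutchPlusPlusEstimator,
    NormalSampler,
    RademacherSampler,
    SphereSampler,
    XNysTraceEstimator,
    XTraceEstimator,
    trace,                  # new functional API used in the README example
)
```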
278 changes: 17 additions & 261 deletions src/traceax/_estimators.py
@@ -13,44 +13,32 @@
 # limitations under the License.
 
 from abc import abstractmethod
-from typing import Any
+from typing import Any, Generic, TypeVar
 
 import equinox as eqx
-import jax
-import jax.numpy as jnp
-import jax.scipy as jsp
 
 from equinox import AbstractVar
-from jax.numpy.linalg import norm
-from jaxtyping import Array, PRNGKeyArray
-from lineax import AbstractLinearOperator, is_negative_semidefinite, is_positive_semidefinite
+from jaxtyping import Array, PRNGKeyArray, PyTree
+from lineax import AbstractLinearOperator
 
-from ._samplers import AbstractSampler, RademacherSampler, SphereSampler
+from ._samplers import AbstractSampler
 
 
-def _check_shapes(operator: AbstractLinearOperator, k: int) -> tuple[int, int]:
-    n_in = operator.in_size()
-    n_out = operator.out_size()
-    if n_in != n_out:
-        raise ValueError(f"Trace estimation requires square linear operator. Found {(n_out, n_in)}.")
-
-    if k < 1:
-        raise ValueError(f"Trace estimation requires positive number of matvecs. Found {k}.")
-
-    return n_in, k
-
-
-def _get_scale(W: Array, D: Array, n: int, k: int) -> Array:
-    return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2)
+_EstimatorState = TypeVar("_EstimatorState")
 
 
-class AbstractTraceEstimator(eqx.Module, strict=True):
+class AbstractEstimator(eqx.Module, Generic[_EstimatorState], strict=True):
     r"""Abstract base class for all trace estimators."""
 
     sampler: AbstractVar[AbstractSampler]
 
     @abstractmethod
-    def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
+    def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _EstimatorState:
+        """ """
+        ...
+
+    @abstractmethod
+    def estimate(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]:
         """Estimate the trace of `operator`.
 
         !!! Example
@@ -59,7 +47,7 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int)
             key = jax.random.PRNGKey(...)
             operator = lx.MatrixLinearOperator(...)
             hutch = tx.HutchinsonEstimator()
-            result = hutch.compute(key, operator, k=10)
+            result = hutch.estimate(key, operator, k=10)
             # or
             result = hutch(key, operator, k=10)
             ```
@@ -79,241 +67,9 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int)
         """
         ...
 
-    def __call__(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
+    def __call__(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]:
         """An alias for `estimate`."""
-        return self.estimate(key, operator, k)
+        return self.estimate(state, k)
+
+    @abstractmethod
+    def transpose(self, state: _EstimatorState) -> _EstimatorState: ...
-
-
-class HutchinsonEstimator(AbstractTraceEstimator):
-    r"""Girard-Hutchinson Trace Estimator:
-
-    $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$,
-    where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
-
-    """
-
-    sampler: AbstractSampler = RademacherSampler()
-
-    def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
-        n, k = _check_shapes(operator, k)
-        # sample from proposed distribution
-        samples = self.sampler(key, n, k)
-
-        # project to k-dim space
-        projected = jax.vmap(operator.mv, (1,), 1)(samples)
-
-        # take the mean across estimates
-        trace_est = jnp.sum(projected * samples) / k
-
-        return trace_est, {}
-
-
-HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:**
-
-- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][].
-"""
-
-
-class HutchPlusPlusEstimator(AbstractTraceEstimator):
-    r"""Hutch++ Trace Estimator:
-
-    Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be a _low-rank approximation_
-    to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for
-    $\Omega = [\omega_1, \dotsc, \omega_k]$.
-
-    Hutch++ improves upon the Girard-Hutchinson estimator by including the trace of the residuals. Namely,
-    Hutch++ estimates $\text{trace}(\mathbf{A})$ as
-    $\text{trace}(\hat{\mathbf{A}}) + \text{trace}(\mathbf{A} - \hat{\mathbf{A}})$.
-
-    As with the Girard-Hutchinson estimator, it requires
-    $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
-
-    """
-
-    sampler: AbstractSampler = RademacherSampler()
-
-    def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
-        # generate an n, k matrix X
-        n, k = _check_shapes(operator, k)
-        m = k // 3
-
-        # some operators work fine with matrices in mv, some don't; this ensures they all do
-        mv = jax.vmap(operator.mv, (1,), 1)
-
-        # split the samples into X1 and X2, each with m = k // 3 columns
-        samples = self.sampler(key, n, 2 * m)
-        X1 = samples[:, :m]
-        X2 = samples[:, m:]
-
-        Y = mv(X1)
-
-        # compute Q, _ = QR(Y) (orthogonal matrix)
-        Q, _ = jnp.linalg.qr(Y)
-
-        # compute G = X2 - Q @ (Q.T @ X2)
-        G = X2 - Q @ (Q.T @ X2)
-
-        # estimate trace = tr(Q.T @ A @ Q) + tr(G.T @ A @ G) / k
-        AQ = mv(Q)
-        AG = mv(G)
-        trace_est = jnp.sum(AQ * Q) + jnp.sum(AG * G) / (G.shape[1])
-
-        return trace_est, {}
-
-
-HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:**
-
-- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][].
-"""
-
-
-class XTraceEstimator(AbstractTraceEstimator):
-    r"""XTrace Trace Estimator:
-
-    Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the _low-rank approximation_
-    to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for
-    $\Omega = [\omega_1, \dotsc, \omega_k]$.
-
-    XTrace improves upon the Hutch++ estimator by enforcing *exchangeability* of sampled test-vectors,
-    to construct a symmetric estimation function with lower variance.
-
-    Additionally, the *improved* XTrace algorithm (i.e. `improved = True`) ensures that test-vectors
-    are orthogonalized against the low-rank approximation $\mathbf{Q}\mathbf{Q}^* \mathbf{A}$ and
-    renormalized. This improved XTrace approach may provide better empirical results compared with
-    the non-orthogonalized version.
-
-    As with the Girard-Hutchinson estimator, it requires
-    $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
-
-    """
-
-    sampler: AbstractSampler = SphereSampler()
-    improved: bool = True
-
-    def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
-        n, k = _check_shapes(operator, k)
-        m = k // 2
-
-        # some operators work fine with matrices in mv, some don't; this ensures they all do
-        mv = jax.vmap(operator.mv, (1,), 1)
-
-        samples = self.sampler(key, n, m)
-        Y = mv(samples)
-        Q, R = jnp.linalg.qr(Y)
-
-        # solve and rescale
-        S = jnp.linalg.inv(R).T
-        s = norm(S, axis=0)
-        S = S / s
-
-        # working variables
-        Z = mv(Q)
-        H = Q.T @ Z
-        W = Q.T @ samples
-        T = Z.T @ samples
-        HW = H @ W
-
-        SW_d = jnp.sum(S * W, axis=0)
-        TW_d = jnp.sum(T * W, axis=0)
-        SHS_d = jnp.sum(S * (H @ S), axis=0)
-        WHW_d = jnp.sum(W * HW, axis=0)
-
-        term1 = SW_d * jnp.sum((T - H.T @ W) * S, axis=0)
-        term2 = (jnp.abs(SW_d) ** 2) * SHS_d
-        term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0)
-
-        if self.improved:
-            scale = _get_scale(W, SW_d, n, m)
-        else:
-            scale = 1
-
-        estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale
-        trace_est = jnp.mean(estimates)
-        std_err = jnp.std(estimates) / jnp.sqrt(m)
-
-        return trace_est, {"std.err": std_err}
-
-
-XTraceEstimator.__init__.__doc__ = r"""**Arguments:**
-
-- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][].
-- `improved`: whether to use the *improved* XTrace estimator, which rescales predicted samples.
-    Default is `True` (see Notes).
-"""
-
-
-class XNysTraceEstimator(AbstractTraceEstimator):
-    r"""XNysTrace Trace Estimator:
-
-    XNysTrace improves upon the XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by
-    performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation),
-    rather than a randomized SVD (i.e., random projection followed by QR decomposition).
-
-    Like [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`) ensures
-    that test-vectors are orthogonalized against the low-rank approximation and renormalized.
-    This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version.
-
-    As with the Girard-Hutchinson estimator, it requires
-    $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
-
-    """
-
-    sampler: AbstractSampler = SphereSampler()
-    improved: bool = True
-
-    def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
-        is_nsd = is_negative_semidefinite(operator)
-        if not (is_positive_semidefinite(operator) | is_nsd):
-            raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators")
-        if is_nsd:
-            operator = -operator
-
-        n, k = _check_shapes(operator, k)
-
-        # some operators work fine with matrices in mv, some don't; this ensures they all do
-        mv = jax.vmap(operator.mv, (1,), 1)
-
-        samples = self.sampler(key, n, k)
-        Y = mv(samples)
-
-        # shift for numerical issues
-        nu = jnp.finfo(Y.dtype).eps * norm(Y, "fro") / jnp.sqrt(n)
-        Y = Y + samples * nu
-        Q, R = jnp.linalg.qr(Y)
-
-        # compute and symmetrize H, then take cholesky factor
-        H = samples.T @ Y
-        C = jnp.linalg.cholesky(0.5 * (H + H.T)).T
-        B = jsp.linalg.solve_triangular(C.T, R.T, lower=True).T
-
-        # if improved == True
-        Qs, Rs = jnp.linalg.qr(samples)
-        Ws = Qs.T @ samples
-
-        # solve and rescale
-        if self.improved:
-            S = jnp.linalg.inv(Rs).T
-            s = norm(S, axis=0)
-            S = S / s
-            scale = _get_scale(Ws, jnp.sum(S * Ws, axis=0), n, k)
-        else:
-            scale = 1
-
-        W = Q.T @ samples
-        S = jsp.linalg.solve_triangular(C, B.T).T / jnp.sqrt(jnp.diag(jnp.linalg.inv(H)))
-        dSW = jnp.sum(S * W, axis=0)
-
-        estimates = norm(B, "fro") ** 2 - norm(S, axis=0) ** 2 + (jnp.abs(dSW) ** 2) * scale - nu * n
-        trace_est = jnp.mean(estimates)
-        std_err = jnp.std(estimates) / jnp.sqrt(k)
-        trace_est = jnp.where(is_nsd, -trace_est, trace_est)
-
-        return trace_est, {"std.err": std_err}
-
-
-XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:**
-
-- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][].
-- `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples.
-    Default is `True` (see Notes).
-"""
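The diff above only shows the abstract side of the refactor; the concrete estimators and the `trace` function were moved to `src/traceax/_trace.py`, which is not included in this view. To make the new `init`/`estimate`/`transpose` contract concrete, here is a hypothetical Hutchinson-style subclass written against the signatures above. The `_HutchState` container, its fields, and the sampler wiring are assumptions for illustration only; the PR's actual state layout in `_trace.py` may differ:

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
from lineax import AbstractLinearOperator

import traceax as tx


class _HutchState(NamedTuple):
    # hypothetical state: just the PRNG key and the operator to probe
    key: PRNGKeyArray
    operator: AbstractLinearOperator


class NaiveHutchinson(tx.AbstractEstimator[_HutchState]):
    sampler: tx.AbstractSampler = tx.RademacherSampler()

    def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _HutchState:
        return _HutchState(key, operator)

    def estimate(self, state: _HutchState, k: int) -> tuple[Array, dict[str, Any]]:
        n = state.operator.in_size()
        omega = self.sampler(state.key, n, k)  # (n, k) test vectors
        # vmap over columns so operators that only accept vectors in mv still work
        projected = jax.vmap(state.operator.mv, (1,), 1)(omega)
        return jnp.sum(projected * omega) / k, {}

    def transpose(self, state: _HutchState) -> _HutchState:
        # trace(A) == trace(A^T), so transposing the stored operator suffices
        return _HutchState(state.key, state.operator.transpose())
```

Splitting `init` from `estimate` mirrors the init/apply convention common in the JAX ecosystem, and the `transpose` rule is presumably what the primitive-construction work mentioned in the commit log (`eqxi`) hooks into.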