diff --git a/README.md b/README.md index 9f377cb..f3688c6 100644 --- a/README.md +++ b/README.md @@ -70,20 +70,20 @@ key, key1, key2, key3, key4 = rdm.split(key, 5) # Hutchinson estimator; default samples Rademacher {-1,+1} hutch = tx.HutchinsonEstimator() -print(hutch.estimate(key1, operator, k)) # (Array(3.4099615, dtype=float32), {}) +print(tx.trace(key1, operator, k, hutch)) # (Array(3.6007538, dtype=float32), {}) # Hutch++ estimator; default samples Rademacher {-1,+1} hpp = tx.HutchPlusPlusEstimator() -print(hpp.estimate(key2, operator, k)) # (Array(3.3033807, dtype=float32), {}) +print(tx.trace(key2, operator, k, hpp)) # (Array(3.4094956, dtype=float32), {}) # XTrace estimator; default samples uniformly on n-Sphere xt = tx.XTraceEstimator() -print(xt.estimate(key3, operator, k)) # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)}) +print(tx.trace(key3, operator, k, xt)) # (Array(3.3030486, dtype=float32), {'std.err': Array(0.01238528, dtype=float32)}) # XNysTrace estimator; Improved performance for NSD/PSD trace estimates operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag) nt = tx.XNysTraceEstimator() -print(nt.estimate(key4, operator, k)) # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)}) +print(tx.trace(key4, operator, k, nt)) # (Array(3.3314352, dtype=float32), {'std.err': Array(0.0006521, dtype=float32)}) ``` ## Documentation diff --git a/src/traceax/__init__.py b/src/traceax/__init__.py index 1fed170..f16823e 100644 --- a/src/traceax/__init__.py +++ b/src/traceax/__init__.py @@ -15,11 +15,7 @@ from importlib.metadata import PackageNotFoundError, version # pragma: no cover from ._estimators import ( - AbstractTraceEstimator as AbstractTraceEstimator, - HutchinsonEstimator as HutchinsonEstimator, - HutchPlusPlusEstimator as HutchPlusPlusEstimator, - XNysTraceEstimator as XNysTraceEstimator, - XTraceEstimator as XTraceEstimator, + AbstractEstimator as AbstractEstimator, ) 
from ._samplers import ( AbstractSampler as AbstractSampler, @@ -27,6 +23,13 @@ RademacherSampler as RademacherSampler, SphereSampler as SphereSampler, ) +from ._trace import ( + HutchinsonEstimator as HutchinsonEstimator, + HutchPlusPlusEstimator as HutchPlusPlusEstimator, + trace as trace, + XNysTraceEstimator as XNysTraceEstimator, + XTraceEstimator as XTraceEstimator, +) try: diff --git a/src/traceax/_estimators.py b/src/traceax/_estimators.py index ec93870..9c67b3a 100644 --- a/src/traceax/_estimators.py +++ b/src/traceax/_estimators.py @@ -13,44 +13,32 @@ # limitations under the License. from abc import abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar import equinox as eqx -import jax -import jax.numpy as jnp -import jax.scipy as jsp from equinox import AbstractVar -from jax.numpy.linalg import norm -from jaxtyping import Array, PRNGKeyArray -from lineax import AbstractLinearOperator, is_negative_semidefinite, is_positive_semidefinite +from jaxtyping import Array, PRNGKeyArray, PyTree +from lineax import AbstractLinearOperator -from ._samplers import AbstractSampler, RademacherSampler, SphereSampler +from ._samplers import AbstractSampler -def _check_shapes(operator: AbstractLinearOperator, k: int) -> tuple[int, int]: - n_in = operator.in_size() - n_out = operator.out_size() - if n_in != n_out: - raise ValueError(f"Trace estimation requires square linear operator. Found {(n_out, n_in)}.") +_EstimatorState = TypeVar("_EstimatorState") - if k < 1: - raise ValueError(f"Trace estimation requires positive number of matvecs. 
Found {k}.") - return n_in, k - - -def _get_scale(W: Array, D: Array, n: int, k: int) -> Array: - return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2) - - -class AbstractTraceEstimator(eqx.Module, strict=True): +class AbstractEstimator(eqx.Module, Generic[_EstimatorState], strict=True): r"""Abstract base class for all trace estimators.""" sampler: AbstractVar[AbstractSampler] @abstractmethod - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: + def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _EstimatorState: + """ """ + ... + + @abstractmethod + def estimate(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: """Estimate the trace of `operator`. !!! Example @@ -59,7 +47,7 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) key = jax.random.PRNGKey(...) operator = lx.MatrixLinearOperator(...) hutch = tx.HutchinsonEstimator() - result = hutch.compute(key, operator, k=10) + result = hutch.estimate(key, operator, k=10) # or result = hutch(key, operator, k=10) ``` @@ -79,241 +67,9 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) """ ... - def __call__(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: + def __call__(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: """An alias for `estimate`.""" - return self.estimate(key, operator, k) - - -class HutchinsonEstimator(AbstractTraceEstimator): - r"""Girard-Hutchinson Trace Estimator: - - $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$, - where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 
- - """ - - sampler: AbstractSampler = RademacherSampler() - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - n, k = _check_shapes(operator, k) - # sample from proposed distribution - samples = self.sampler(key, n, k) - - # project to k-dim space - projected = jax.vmap(operator.mv, (1,), 1)(samples) - - # take the mean across estimates - trace_est = jnp.sum(projected * samples) / k - - return trace_est, {} - - -HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. -""" - - -class HutchPlusPlusEstimator(AbstractTraceEstimator): - r"""Hutch++ Trace Estimator: - - Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the a _low-rank approximation_ - to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for - $\Omega = [\omega_1, \dotsc, \omega_k]$. - - Hutch++ improves upon Girard-Hutchinson estimator by including the trace of the residuals. Namely, - Hutch++ estimates $\text{trace}(\mathbf{A})$ as - $\text{trace}(\hat{\mathbf{A}}) - \text{trace}(\mathbf{A} - \hat{\mathbf{A}})$. - - As with the Girard-Hutchinson estimator, it requires - $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 
- - """ - - sampler: AbstractSampler = RademacherSampler() - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - # generate an n, k matrix X - n, k = _check_shapes(operator, k) - m = k // 3 - - # some operators work fine with matrices in mv, some dont; this ensures they all do - mv = jax.vmap(operator.mv, (1,), 1) - - # split X into 2 Xs; X1 and X2, where X1 has shape 2m, where m = k/3 - samples = self.sampler(key, n, 2 * m) - X1 = samples[:, :m] - X2 = samples[:, m:] - - Y = mv(X1) - - # compute Q, _ = QR(Y) (orthogonal matrix) - Q, _ = jnp.linalg.qr(Y) - - # compute G = X2 - Q @ (Q.T @ X2) - G = X2 - Q @ (Q.T @ X2) - - # estimate trace = tr(Q.T @ A @ Q) + tr(G.T @ A @ G) / k - AQ = mv(Q) - AG = mv(G) - trace_est = jnp.sum(AQ * Q) + jnp.sum(AG * G) / (G.shape[1]) - - return trace_est, {} - - -HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. -""" - - -class XTraceEstimator(AbstractTraceEstimator): - r"""XTrace Trace Estimator: - - Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the the _low-rank approximation_ - to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for - $\Omega = [\omega_1, \dotsc, \omega_k]$. + return self.estimate(state, k) - XTrace improves upon Hutch++ estimator by enforcing *exchangeability* of sampled test-vectors, - to construct a symmetric estimation function with lower variance. - - Additionally, the *improved* XTrace algorithm (i.e. `improved = True`), ensures that test-vectors - are orthogonalized against the low rank approximation $\mathbf{Q}\mathbf{Q}^* \mathbf{A}$ and - renormalized. This improved XTrace approach may provide better empirical results compared with - the non-orthogonalized version. 
- - As with the Girard-Hutchinson estimator, it requires - $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. - - """ - - sampler: AbstractSampler = SphereSampler() - improved: bool = True - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - n, k = _check_shapes(operator, k) - m = k // 2 - - # some operators work fine with matrices in mv, some dont; this ensures they all do - mv = jax.vmap(operator.mv, (1,), 1) - - samples = self.sampler(key, n, m) - Y = mv(samples) - Q, R = jnp.linalg.qr(Y) - - # solve and rescale - S = jnp.linalg.inv(R).T - s = norm(S, axis=0) - S = S / s - - # working variables - Z = mv(Q) - H = Q.T @ Z - W = Q.T @ samples - T = Z.T @ samples - HW = H @ W - - SW_d = jnp.sum(S * W, axis=0) - TW_d = jnp.sum(T * W, axis=0) - SHS_d = jnp.sum(S * (H @ S), axis=0) - WHW_d = jnp.sum(W * HW, axis=0) - - term1 = SW_d * jnp.sum((T - H.T @ W) * S, axis=0) - term2 = (jnp.abs(SW_d) ** 2) * SHS_d - term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) - - if self.improved: - scale = _get_scale(W, SW_d, n, m) - else: - scale = 1 - - estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale - trace_est = jnp.mean(estimates) - std_err = jnp.std(estimates) / jnp.sqrt(m) - - return trace_est, {"std.err": std_err} - - -XTraceEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. -- `improved`: whether to use the *improved* XTrace estimator, which rescales predicted samples. - Default is `True` (see Notes). 
-""" - - -class XNysTraceEstimator(AbstractTraceEstimator): - r"""XNysTrace Trace Estimator: - - XNysTrace improves upon XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by - performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation), - rather than a randomized SVD (i.e., random projection followed by QR decomposition). - - Like, [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`), ensures - that test-vectors are orthogonalized against the low rank approximation and renormalized. - This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version. - - As with the Girard-Hutchinson estimator, it requires - $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. - - """ - - sampler: AbstractSampler = SphereSampler() - improved: bool = True - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - is_nsd = is_negative_semidefinite(operator) - if not (is_positive_semidefinite(operator) | is_nsd): - raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators") - if is_nsd: - operator = -operator - - n, k = _check_shapes(operator, k) - - # some operators work fine with matrices in mv, some dont; this ensures they all do - mv = jax.vmap(operator.mv, (1,), 1) - - samples = self.sampler(key, n, k) - Y = mv(samples) - - # shift for numerical issues - nu = jnp.finfo(Y.dtype).eps * norm(Y, "fro") / jnp.sqrt(n) - Y = Y + samples * nu - Q, R = jnp.linalg.qr(Y) - - # compute and symmetrize H, then take cholesky factor - H = samples.T @ Y - C = jnp.linalg.cholesky(0.5 * (H + H.T)).T - B = jsp.linalg.solve_triangular(C.T, R.T, lower=True).T - - # if improved == True - Qs, Rs = jnp.linalg.qr(samples) - Ws = Qs.T @ samples - - # solve and rescale - if self.improved: - S = 
jnp.linalg.inv(Rs).T - s = norm(S, axis=0) - S = S / s - scale = _get_scale(Ws, jnp.sum(S * Ws, axis=0), n, k) - else: - scale = 1 - - W = Q.T @ samples - S = jsp.linalg.solve_triangular(C, B.T).T / jnp.sqrt(jnp.diag(jnp.linalg.inv(H))) - dSW = jnp.sum(S * W, axis=0) - - estimates = norm(B, "fro") ** 2 - norm(S, axis=0) ** 2 + (jnp.abs(dSW) ** 2) * scale - nu * n - trace_est = jnp.mean(estimates) - std_err = jnp.std(estimates) / jnp.sqrt(k) - trace_est = jnp.where(is_nsd, -trace_est, trace_est) - - return trace_est, {"std.err": std_err} - - -XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. -- `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples. - Default is `True` (see Notes). -""" + @abstractmethod + def transpose(self, state: _EstimatorState) -> _EstimatorState: ... \ No newline at end of file diff --git a/src/traceax/_solution.py b/src/traceax/_solution.py new file mode 100644 index 0000000..e4a4e1c --- /dev/null +++ b/src/traceax/_solution.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 MancusoLab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import equinox as eqx + +from jaxtyping import ArrayLike, PyTree + + +class Solution(eqx.Module, strict=True): + """The solution to a stochastic estimation problem. + + **Attributes:** + + - `value`: The estimated value. 
+ - `stats`: A dictionary containing statistics about the solution (e.g., standard error). + This may be empty if individual estimators cannot provide this information (i.e. `{}`) + - `state`: The internal state for the estimator. + """ + + value: PyTree[Any] + stats: dict[str, PyTree[ArrayLike]] + state: PyTree[Any] diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py new file mode 100644 index 0000000..7e5688e --- /dev/null +++ b/src/traceax/_trace.py @@ -0,0 +1,566 @@ +# Copyright (c) 2024 MancusoLab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import functools as ft + +from typing import Any +from typing_extensions import TypeAlias + +import equinox as eqx +import equinox.internal as eqxi +import jax.lax as lax +import jax.numpy as jnp +import jax.random as rdm +import jax.scipy as jsp +import jax.tree_util as jtu +import lineax as lx + +from jax.interpreters import ad as ad, mlir as mlir +from jax.numpy.linalg import norm +from jaxtyping import Array, PRNGKeyArray, PyTree + +from ._estimators import AbstractEstimator +from ._samplers import AbstractSampler, RademacherSampler, SphereSampler +from ._solution import Solution +from ._utils import ( + _assert_false, + _check_operator, + _clip_k, + _is_none, + _is_undefined, + _remove_undefined_primal, + _to_shapedarray, + _to_struct, + _vmap_mv, + sentinel, +) + + +def _get_scale(W: Array, D: Array, n: int, k: int) -> Array: + return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2) + + +_BasicTraceState: TypeAlias = tuple[PRNGKeyArray, lx.AbstractLinearOperator, int] +_PSDTraceState: TypeAlias = tuple[PRNGKeyArray, lx.AbstractLinearOperator, int, bool] + + +class HutchinsonEstimator(AbstractEstimator[_BasicTraceState], strict=True): + r"""Girard-Hutchinson Trace Estimator: + + $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$, + where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 
+ + """ + + sampler: AbstractSampler = RademacherSampler() + + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _BasicTraceState: + n = _check_operator(operator) + return (key, operator, n) + + def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n = state + + k = _clip_k(k, n) + + # sample from proposed distribution + samples = self.sampler(key, n, k) + + # project to k-dim space + projected = _vmap_mv(operator)(samples) + + # take the mean across estimates + trace_est = jnp.sum(projected * samples) / k + + return trace_est, {} + + def transpose(self, state: _BasicTraceState) -> _BasicTraceState: + key, operator, n = state + return key, operator.transpose(), n + + +HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. +""" + + +class HutchPlusPlusEstimator(AbstractEstimator[_BasicTraceState], strict=True): + r"""Hutch++ Trace Estimator: + + Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the a _low-rank approximation_ + to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for + $\Omega = [\omega_1, \dotsc, \omega_k]$. + + Hutch++ improves upon Girard-Hutchinson estimator by including the trace of the residuals. Namely, + Hutch++ estimates $\text{trace}(\mathbf{A})$ as + $\text{trace}(\hat{\mathbf{A}}) - \text{trace}(\mathbf{A} - \hat{\mathbf{A}})$. + + As with the Girard-Hutchinson estimator, it requires + $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 
+ + """ + + sampler: AbstractSampler = RademacherSampler() + + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _BasicTraceState: + n = _check_operator(operator) + return (key, operator, n) + + def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n = state + + k = _clip_k(k, n) + m = k // 3 + + # some operators work fine with matrices in mv, some dont; this ensures they all do + mv = _vmap_mv(operator) + + # split X into 2 Xs; X1 and X2, where X1 has shape 2m, where m = k/3 + samples = self.sampler(key, n, 2 * m) + X1 = samples[:, :m] + X2 = samples[:, m:] + + Y = mv(X1) + + # compute Q, _ = QR(Y) (orthogonal matrix) + Q, _ = jnp.linalg.qr(Y) + + # compute G = X2 - Q @ (Q.T @ X2) + G = X2 - Q @ (Q.T @ X2) + + # estimate trace = tr(Q.T @ A @ Q) + tr(G.T @ A @ G) / k + AQ = mv(Q) + AG = mv(G) + trace_est = jnp.sum(AQ * Q) + jnp.sum(AG * G) / (G.shape[1]) + + return trace_est, {} + + def transpose(self, state: _BasicTraceState) -> _BasicTraceState: + key, operator, n = state + return key, operator.transpose(), n + + +HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. +""" + + +class XTraceEstimator(AbstractEstimator[_BasicTraceState], strict=True): + r"""XTrace Trace Estimator: + + Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the the _low-rank approximation_ + to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for + $\Omega = [\omega_1, \dotsc, \omega_k]$. + + XTrace improves upon Hutch++ estimator by enforcing *exchangeability* of sampled test-vectors, + to construct a symmetric estimation function with lower variance. + + Additionally, the *improved* XTrace algorithm (i.e. 
`improved = True`), ensures that test-vectors + are orthogonalized against the low rank approximation $\mathbf{Q}\mathbf{Q}^* \mathbf{A}$ and + renormalized. This improved XTrace approach may provide better empirical results compared with + the non-orthogonalized version. + + As with the Girard-Hutchinson estimator, it requires + $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. + + """ + + sampler: AbstractSampler = SphereSampler() + improved: bool = True + + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _BasicTraceState: + n = _check_operator(operator) + return (key, operator, n) + + def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n = state + + k = _clip_k(k, n) + m = k // 2 + + # some operators work fine with matrices in mv, some dont; this ensures they all do + mv = _vmap_mv(operator) + + samples = self.sampler(key, n, m) + Y = mv(samples) + Q, R = jnp.linalg.qr(Y) + + # solve and rescale + S = jnp.linalg.inv(R).T + s = norm(S, axis=0) + S = S / s + + # working variables + Z = mv(Q) + H = Q.T @ Z + W = Q.T @ samples + T = Z.T @ samples + HW = H @ W + + SW_d = jnp.sum(S * W, axis=0) + TW_d = jnp.sum(T * W, axis=0) + SHS_d = jnp.sum(S * (H @ S), axis=0) + WHW_d = jnp.sum(W * HW, axis=0) + + term1 = SW_d * jnp.sum((T - H.T @ W) * S, axis=0) + term2 = (jnp.abs(SW_d) ** 2) * SHS_d + term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) + + if self.improved: + scale = _get_scale(W, SW_d, n, m) + else: + scale = 1 + + estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale + trace_est = jnp.mean(estimates) + std_err = jnp.std(estimates) / jnp.sqrt(m) + + return trace_est, {"std.err": std_err} + + def transpose(self, state: _BasicTraceState) -> _BasicTraceState: + key, operator, n = state + return key, operator.transpose(), n + + +XTraceEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: the sampling 
distribution for $\omega$. Default is [`traceax.SphereSampler`][]. +- `improved`: whether to use the *improved* XTrace estimator, which rescales predicted samples. + Default is `True` (see Notes). +""" + + +class XNysTraceEstimator(AbstractEstimator[_PSDTraceState], strict=True): + r"""XNysTrace Trace Estimator: + + XNysTrace improves upon XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by + performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation), + rather than a randomized SVD (i.e., random projection followed by QR decomposition). + + Like, [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`), ensures + that test-vectors are orthogonalized against the low rank approximation and renormalized. + This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version. + + As with the Girard-Hutchinson estimator, it requires + $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. 
+ + """ + + sampler: AbstractSampler = SphereSampler() + improved: bool = True + + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _PSDTraceState: + n = _check_operator(operator) + is_nsd = lx.is_negative_semidefinite(operator) + if not (lx.is_positive_semidefinite(operator) | is_nsd): + raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators") + if is_nsd: + operator = -operator + + return (key, operator, n, is_nsd) + + def estimate(self, state: _PSDTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n, is_nsd = state + + k = _clip_k(k, n) + + # some operators work fine with matrices in mv, some dont; this ensures they all do + mv = _vmap_mv(operator) + + samples = self.sampler(key, n, k) + Y = mv(samples) + + # shift for numerical issues + nu = jnp.finfo(Y.dtype).eps * norm(Y, "fro") / jnp.sqrt(n) + Y = Y + samples * nu + Q, R = jnp.linalg.qr(Y) + + # compute and symmetrize H, then take cholesky factor + H = samples.T @ Y + C = jnp.linalg.cholesky(0.5 * (H + H.T)).T + B = jsp.linalg.solve_triangular(C.T, R.T, lower=True).T + + # if improved == True + Qs, Rs = jnp.linalg.qr(samples) + Ws = Qs.T @ samples + + # solve and rescale + if self.improved: + S = jnp.linalg.inv(Rs).T + s = norm(S, axis=0) + S = S / s + scale = _get_scale(Ws, jnp.sum(S * Ws, axis=0), n, k) + else: + scale = 1 + + W = Q.T @ samples + S = jsp.linalg.solve_triangular(C, B.T).T / jnp.sqrt(jnp.diag(jnp.linalg.inv(H))) + dSW = jnp.sum(S * W, axis=0) + + estimates = norm(B, "fro") ** 2 - norm(S, axis=0) ** 2 + (jnp.abs(dSW) ** 2) * scale - nu * n + trace_est = jnp.mean(estimates) + std_err = jnp.std(estimates) / jnp.sqrt(k) + trace_est = jnp.where(is_nsd, -trace_est, trace_est) + + return trace_est, {"std.err": std_err} + + def transpose(self, state: _PSDTraceState) -> _PSDTraceState: + key, operator, n, is_nsd = state + return key, operator.transpose(), n, is_nsd + + 
+XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. +- `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples. + Default is `True` (see Notes). +""" + + +def _estimate_trace_impl(key, operator, state, k, estimator, *, check_closure): + out = estimator.estimate(state, k) + if check_closure: + out = eqxi.nontraceable(out, name="`traceax.trace` with respect to a closed-over value") + result, stats = out + + return result, stats + + +_to_struct_tr = ft.partial(_to_struct, name="traceax.trace") + + +@eqxi.filter_primitive_def +def _estimate_trace_abstract_eval(key, operator, state, k, estimator): + key, state, k, estimator = jtu.tree_map(_to_struct_tr, (key, state, k, estimator)) + out = eqx.filter_eval_shape( + _estimate_trace_impl, + key, + operator, + state, + k, + estimator, + check_closure=False, + ) + out = jtu.tree_map(_to_shapedarray, out) + + return out + + +@ft.singledispatch +def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + raise ValueError("Unsupported type!") + + +@_make_identity.register +def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size, out_size = eqx.filter_eval_shape(lambda o: (o.in_size(), o.out_size()), operator_struct) + if in_size != out_size: + raise ValueError("`_make_identity` only supports square matrices.") + diag = jnp.full(in_size, ct_result) + out = lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) + return out + + +@_make_identity.register +def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = jnp.array(1.0) + return lx.MulLinearOperator(inner_op, scalar*ct_result) + + +@_make_identity.register +def _(op: 
lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + p_op = op.primal + t_op = op.tangent + return lx.TangentLinearOperator(p_op, _make_identity(t_op, ct_result)) + + +@_make_identity.register +def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = jnp.full(in_size, ct_result) + return lx.DiagonalLinearOperator(diag) + + +@_make_identity.register +def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = jnp.full(in_size, ct_result) + off_diag = jnp.zeros(in_size - 1) + return lx.TridiagonalLinearOperator(diag, off_diag, off_diag) + + +@_make_identity.register +def _(op: lx.AddLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op1 = _make_identity(op.operator1, ct_result) + inner_op2 = _make_identity(op.operator2, ct_result) + return lx.AddLinearOperator(inner_op1, inner_op2) + + +@_make_identity.register +def _(op: lx.NegLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + return lx.NegLinearOperator(inner_op) + + +@_make_identity.register +def _(op: lx.DivLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = op.scalar + return lx.DivLinearOperator(inner_op, scalar) + + +@_make_identity.register +def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op1 = _make_identity(op.operator1, ct_result) + inner_op2 = _make_identity(op.operator2, ct_result) + return lx.ComposedLinearOperator(inner_op1, inner_op2) + + +@eqxi.filter_primitive_jvp +def 
_estimate_trace_jvp(primals, tangents): + key, operator, state, k, estimator = primals + # t_operator := V + t_key, t_operator, t_state, t_k, t_estimator = tangents + jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) + del t_key, t_state, t_k, t_estimator + + # primal problem of t = tr(A) + result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) + out = result, stats + + # inner prodct in linear operator space => = tr(A @ B) + # d tr(A) / dA = I + # t' = = tr(I @ V) = tr(V) + # tangent problem => tr(V) + # TODO: should we reuse key or split? both seem confusing options + key, t_key = rdm.split(key) + + t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) + t_operator = lx.TangentLinearOperator(operator, t_operator) + + t_state = estimator.init(t_key, t_operator) + t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) + # t_result = jnp.trace(t_operator.as_matrix()) + t_out = ( + t_result, + jtu.tree_map(lambda _: None, stats), + ) + + return out, t_out + + +@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore +def _estimate_trace_transpose(inputs, cts_out): + # the jacobian, for the trace is just the identity matrix, i.e. 
J = I + # so J'v = I v = v + + # primal inputs; operator should have UndefinedPrimal leaves + key, operator, state, _, estimator = inputs + + # co-tangent of the trace approximation and the stats (None) + cts_result, _ = cts_out + + # the internals of the operator are UndefinedPrimal leaves so + # we need to rely on abstract values to pull structure info + op_t = _make_identity(operator, cts_result) + + key_none = jtu.tree_map(lambda _: None, key) + state_none = (None, op_t, None) + k_none = None + estimator_none = jtu.tree_map(lambda _: None, estimator) + + return key_none, op_t, state_none, k_none, estimator_none + + +_estimate_trace_p = eqxi.create_vprim( + "trace", + eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=False)), + _estimate_trace_abstract_eval, + _estimate_trace_jvp, + _estimate_trace_transpose, +) +# rebind here to allow closure checks +_estimate_trace_p.def_impl( + eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=True)), +) +eqxi.register_impl_finalisation(_estimate_trace_p) + + +@eqx.filter_jit +def trace( + key: PRNGKeyArray, + operator: lx.AbstractLinearOperator, + k: int, + estimator: AbstractEstimator = XTraceEstimator(), + *, + state: PyTree[Any] = sentinel, +) -> Solution: + r""" """ + if eqx.is_array(operator): + raise ValueError( + "`traceax.trace(..., operator=...)` should be an " + "`lineax.AbstractLinearOperator`, not a raw JAX array. If you are trying to pass " + "a matrix then this should be passed as " + "`lineax.MatrixLinearOperator(matrix)`." + ) + + in_size = operator.in_size() + out_size = operator.out_size() + if in_size != out_size: + raise ValueError( + "`traceax.trace(..., operator=...)` should be a square `lineax.AbstractLinearOperator`. " + f"Found shape {out_size}x{in_size}." 
+ ) + + # if identity op, then just shortcircuit and return dimension size + if isinstance(operator, lx.IdentityLinearOperator): + return Solution( + value=jnp.asarray(in_size, dtype=float), + stats={}, + state=state, + ) + # if diagonal op, then just shortcircuit and sum diagonal + if isinstance(operator, lx.DiagonalLinearOperator): + return Solution( + value=jnp.sum(operator.diagonal), + stats={}, + state=state, + ) + + # set up state if necessary + if state == sentinel: + state = estimator.init(key, operator) + # we don't want to allow differentiation through trace-alg state, which likely contains the operator + # or by-products of the operator + dynamic_state, static_state = eqx.partition(state, eqx.is_array) + dynamic_state = lax.stop_gradient(dynamic_state) + state = eqx.combine(dynamic_state, static_state) + + # cannot differentiate through key, state, or estimator + key = eqxi.nondifferentiable(key, name="`trace(key, ...)`") + state = eqxi.nondifferentiable(state, name="`trace(..., state=...)`") + estimator = eqxi.nondifferentiable(estimator, name="`trace(..., estimator=...)`") + + # estimate trace and compute stats if any + result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) + + # cannot differentiate backwards through stats + stats = eqxi.nondifferentiable_backward(stats, name="_, stats = trace(...)") + + return Solution(value=result, stats=stats, state=state) diff --git a/src/traceax/_utils.py b/src/traceax/_utils.py new file mode 100644 index 0000000..b5e50b3 --- /dev/null +++ b/src/traceax/_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 MancusoLab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import equinox.internal as eqxi +import jax +import jax.core + +from jax import vmap +from jax.interpreters import ad as ad +from lineax import AbstractLinearOperator + + +sentinel: Any = eqxi.doc_repr(object(), "sentinel") + + +def _check_operator(operator: AbstractLinearOperator) -> int: + n_in = operator.in_size() + n_out = operator.out_size() + if n_in != n_out: + raise ValueError(f"Estimation requires square linear operator. Found {(n_out, n_in)}.") + + return n_in + + +def _clip_k(k: int, n: int) -> int: + return min(max(k, 1), n) + + +def _vmap_mv(operator: AbstractLinearOperator): + return vmap(operator.mv, (1,), 1) + + +def _is_none(x): + return x is None + + +def _to_shapedarray(x): + if isinstance(x, jax.ShapeDtypeStruct): + return jax.core.ShapedArray(x.shape, x.dtype) + else: + return x + + +def _to_struct(x, name): + if isinstance(x, jax.core.ShapedArray): + return jax.ShapeDtypeStruct(x.shape, x.dtype) + elif isinstance(x, jax.core.AbstractValue): + raise NotImplementedError( + f"`{name}` only supports working with JAX arrays; not " f"other abstract values. Got abstract value {x}." 
+ ) + else: + return x + + +def _assert_false(x): + assert False + + +def _is_undefined(x): + return isinstance(x, ad.UndefinedPrimal) + + +def _assert_defined(x): + assert not _is_undefined(x) + + +def _keep_undefined(v, ct): + if _is_undefined(v): + return ct + else: + return None + + +def _remove_undefined_primal(x): + if _is_undefined(x): + return x.aval + else: + return x \ No newline at end of file diff --git a/tests/test_trace.py b/tests/test_trace.py index 1eef35e..d2d6e21 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -14,17 +14,21 @@ import pytest +import jax import jax.numpy as jnp import lineax as lx -import traceax as tr +import traceax as tx from .helpers import ( construct_matrix, ) -@pytest.mark.parametrize("estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator())) +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) @pytest.mark.parametrize("k", (5, 10, 50)) @pytest.mark.parametrize( "tags", @@ -44,14 +48,17 @@ def test_matrix_linop(getkey, estimator, k, tags, size, dtype): k = min(k, size) matrix = construct_matrix(getkey, tags, size, dtype) operator = lx.MatrixLinearOperator(matrix, tags=tags) - result = estimator.estimate(getkey(), operator, k) + result = tx.trace(getkey(), operator, k, estimator) assert result is not None - assert result[0] is not None - assert jnp.isfinite(result[0]) + assert result.value is not None + assert jnp.isfinite(result.value) -@pytest.mark.parametrize("estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator())) +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) @pytest.mark.parametrize("k", (5, 10, 50)) @pytest.mark.parametrize("size", (5, 50, 500)) @pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) @@ -59,15 +66,21 @@ def test_diag_linop(getkey, estimator, k, size, dtype): k = 
min(k, size) matrix = construct_matrix(getkey, lx.diagonal_tag, size, dtype) operator = lx.DiagonalLinearOperator(jnp.diag(matrix)) - result = estimator.estimate(getkey(), operator, k) + result = tx.trace(getkey(), operator, k, estimator) assert result is not None - assert result[0] is not None - assert jnp.isfinite(result[0]) + assert result.value is not None + assert jnp.isfinite(result.value) @pytest.mark.parametrize( - "estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator(), tr.XNysTraceEstimator()) + "estimator", + ( + tx.HutchinsonEstimator(), + tx.HutchPlusPlusEstimator(), + tx.XTraceEstimator(), + tx.XNysTraceEstimator(), + ), ) @pytest.mark.parametrize("k", (5, 10, 50)) @pytest.mark.parametrize("tags", (lx.positive_semidefinite_tag, lx.negative_semidefinite_tag)) @@ -77,8 +90,157 @@ def test_nsd_psd_matrix_linop(getkey, estimator, k, tags, size, dtype): k = min(k, size) matrix = construct_matrix(getkey, tags, size, dtype) operator = lx.MatrixLinearOperator(matrix, tags=tags) - result = estimator.estimate(getkey(), operator, k) + result = tx.trace(getkey(), operator, k, estimator) assert result is not None - assert result[0] is not None - assert jnp.isfinite(result[0]) + assert result.value is not None + assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_tridiagonal_linop(getkey, estimator, k, size, dtype): + k = min(k, size) + matrix = construct_matrix(getkey, lx.tridiagonal_tag, size, dtype) + main_diag = jnp.diag(matrix) + lower_diag = jnp.diag(matrix, k=-1) + upper_diag = jnp.diag(matrix, k=1) + operator = lx.TridiagonalLinearOperator(main_diag, lower_diag, upper_diag) + result = tx.trace(getkey(), operator, k, estimator) + + assert result is not 
None + assert result.value is not None + assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_identity_linop(getkey, estimator, k, size, dtype): + k = min(k, size) + input_structure = jax.ShapeDtypeStruct((size,), dtype) + operator = lx.IdentityLinearOperator(input_structure) + result = tx.trace(getkey(), operator, k, estimator) + + assert result is not None + assert result.value is not None + assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize( + "tags", + ( + lx.diagonal_tag, + lx.symmetric_tag, + lx.lower_triangular_tag, + lx.upper_triangular_tag, + lx.tridiagonal_tag, + lx.unit_diagonal_tag, + lx.positive_semidefinite_tag, + lx.negative_semidefinite_tag, + ), +) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_tagged_linop(getkey, estimator, k, tags, size, dtype): + k = min(k, size) + matrix = construct_matrix(getkey, tags, size, dtype) + operator = lx.MatrixLinearOperator(matrix, tags=tags) + tagged_operator = lx.TaggedLinearOperator(operator, tags=tags) + result = tx.trace(getkey(), tagged_operator, k, estimator) + + assert result is not None + assert result.value is not None + assert jnp.isfinite(result.value) + + sym_operator = operator + operator.T + sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag) + sym_result = tx.trace(getkey(), sym_operator, k, estimator) + + assert lx.is_symmetric(sym_operator) + assert sym_result is not None + assert sym_result.value is not None + assert jnp.isfinite(sym_result.value) + + 
@pytest.mark.parametrize(
    "estimator",
    (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()),
)
@pytest.mark.parametrize(
    "tags",
    (
        None,
        lx.diagonal_tag,
        lx.symmetric_tag,
        lx.lower_triangular_tag,
        lx.upper_triangular_tag,
        lx.tridiagonal_tag,
        lx.unit_diagonal_tag,
        lx.positive_semidefinite_tag,
        lx.negative_semidefinite_tag,
    ),
)
@pytest.mark.parametrize("k", (5, 10, 50))
@pytest.mark.parametrize("size", (5, 50, 500))
@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64))
def test_compound_op(getkey, estimator, k, tags, size, dtype):
    """Check that `tx.trace` yields a finite value for compound operators.

    Two `lx.MatrixLinearOperator`s are built from `construct_matrix` and then
    combined via `+`, `@`, scalar `*`, unary `-` and scalar `/`. Each
    combination is passed to `tx.trace`, and the returned solution must be
    non-None with a non-None, finite `.value`.
    """
    # never request more matvecs than the operator has dimensions
    k = min(k, size)

    op_a = lx.MatrixLinearOperator(construct_matrix(getkey, tags, size, dtype), tags=tags)
    op_b = lx.MatrixLinearOperator(construct_matrix(getkey, tags, size, dtype), tags=tags)

    # arbitrary scalar; built with the parametrized dtype so precision matches
    scalar = jnp.asarray(0.5, dtype=dtype)

    # label -> operator; insertion order fixes the order in which keys are drawn
    compound_ops = {
        "AddLinearOperator": op_a + op_b,
        "ComposedLinearOperator": op_a @ op_b,
        "MulLinearOperator": scalar * op_a,
        "NegLinearOperator": -op_a,
        "DivLinearOperator": op_b / scalar,
    }

    for label, operator in compound_ops.items():
        result = tx.trace(getkey(), operator, k, estimator)  # pyright: ignore

        # the label in each message identifies the failing combination
        assert result is not None, f"{label}: trace returned None"
        assert result.value is not None, f"{label}: trace value is None"
        assert jnp.isfinite(result.value), f"{label}: trace value is not finite"