diff --git a/quimb/tensor/decomp.py b/quimb/tensor/decomp.py index 2b8d03cf..52d40666 100644 --- a/quimb/tensor/decomp.py +++ b/quimb/tensor/decomp.py @@ -25,6 +25,21 @@ from ..linalg import rand_linalg +_CUTOFF_MODE_MAP = { + "abs": 1, + "rel": 2, + "sum2": 3, + "rsum2": 4, + "sum1": 5, + "rsum1": 6, +} + + +def map_cutoff_mode(cutoff_mode): + """Map mode to an integer for compatibility with numba.""" + return _CUTOFF_MODE_MAP.get(cutoff_mode, cutoff_mode) + + # some convenience functions for multiplying diagonals @@ -82,14 +97,14 @@ def sgn(x): """Get the 'sign' of ``x``, such that ``x / sgn(x)`` is real and non-negative. """ - x0 = (x == 0.0) + x0 = x == 0.0 return (x + x0) / (do("abs", x) + x0) @sgn.register("numpy") @njit # pragma: no cover def sgn_numba(x): - x0 = (x == 0.0) + x0 = x == 0.0 return (x + x0) / (np.abs(x) + x0) @@ -176,10 +191,10 @@ def svd_truncated( Parameters ---------- - cutoff : float + cutoff : float, optional Singular value cutoff threshold, if ``cutoff <= 0.0``, then only ``max_bond`` is used. - cutoff_mode : {1, 2, 3, 4, 5, 6} + cutoff_mode : {1, 2, 3, 4, 5, 6}, optional How to perform the trim: - 1: ['abs'], trim values below ``cutoff`` @@ -189,12 +204,12 @@ def svd_truncated( - 5: ['sum1'], trim s.t. ``sum(s_trim**1) < cutoff``. - 6: ['rsum1'], trim s.t. ``sum(s_trim**1) < sum(s**1) * cutoff``. - max_bond : int + max_bond : int, optional An explicit maximum bond dimension, use -1 for none. - absorb : {-1, 0, 1, None} + absorb : {-1, 0, 1, None}, optional How to absorb the singular values. -1: left, 0: both, 1: right and None: don't absorb (return). - renorm : {0, 1} + renorm : {0, 1}, optional Whether to renormalize the singular values (depends on `cutoff_mode`). """ with backend_like(backend): @@ -313,7 +328,12 @@ def svd_truncated_numba( @svd_truncated.register("autoray.lazy") @lazy.core.lazy_cache("svd_truncated") def svd_truncated_lazy( - x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0, + x, + cutoff=-1.0, + cutoff_mode=4, + max_bond=-1, + absorb=0, + renorm=0, ): if cutoff != 0.0: raise ValueError("Can't handle dynamic cutoffs in lazy mode.") @@ -326,7 +346,7 @@ def svd_truncated_lazy( lsvdt = x.to( fn=get_lib_fn(x.backend, "svd_truncated"), args=(x, cutoff, cutoff_mode, max_bond, absorb, renorm), - shape=(3,) + shape=(3,), ) U = lsvdt.to(operator.getitem, (lsvdt, 0), shape=(m, k)) @@ -364,14 +384,14 @@ def lu_truncated( ) with backend_like(backend): - PL, U = do('scipy.linalg.lu', x, permute_l=True) + PL, U = do("scipy.linalg.lu", x, permute_l=True) - sl = do('sum', do('abs', PL), axis=0) - su = do('sum', do('abs', U), axis=1) + sl = do("sum", do("abs", PL), axis=0) + su = do("sum", do("abs", U), axis=1) if cutoff_mode == 2: - abs_cutoff_l = cutoff * do('max', sl) - abs_cutoff_u = cutoff * do('max', su) + abs_cutoff_l = cutoff * do("max", sl) + abs_cutoff_u = cutoff * do("max", su) elif cutoff_mode == 1: abs_cutoff_l = abs_cutoff_u = cutoff else: @@ -943,13 +963,13 @@ def isometrize_cayley(x, backend): "pad", x, [[0, d - m], [0, d - n]], "constant", constant_values=0.0 ) x = x - dag(x) - x = x / 2. + x = x / 2.0 if backend == "torch": # XXX: move device handling upstream in to autoray? Id = do("eye", d, like=x, device=x.device) else: Id = do("eye", d, like=x) - Q = do('linalg.solve', Id - x, Id + x) + Q = do("linalg.solve", Id - x, Id + x) return Q[:m, :n] @@ -974,7 +994,7 @@ def isometrize_modified_gram_schmidt(A, backend=None): def isometrize_householder(X, backend=None): with backend_like(backend): X = do("tril", X, -1) - tau = 2. / (1. + do("sum", do("conj", X) * X, 0)) + tau = 2.0 / (1.0 + do("sum", do("conj", X) * X, 0)) Q = do("linalg.householder_product", X, tau) return Q @@ -1125,7 +1145,7 @@ def squared_op_to_reduced_factor_numba(x2, dl, dr, right=True): def compute_oblique_projectors( - Rl, Rr, max_bond, cutoff, absorb="both", **compress_opts + Rl, Rr, max_bond, cutoff, absorb="both", cutoff_mode=4, **compress_opts ): """Compute the oblique projectors for two reduced factor matrices that describe a gauge on a bond. Concretely, assuming that ``Rl`` and ``Rr`` are @@ -1162,12 +1182,15 @@ def compute_oblique_projectors( if max_bond is None: max_bond = -1 + cutoff_mode = map_cutoff_mode(cutoff_mode) + Ut, st, VHt = svd_truncated( Rl @ Rr, max_bond=max_bond, cutoff=cutoff, absorb=None, - **compress_opts + cutoff_mode=cutoff_mode, + **compress_opts, ) st_sqrt = do("sqrt", st)