diff --git a/filter_functions/basis.py b/filter_functions/basis.py index d4df397..5cf6942 100644 --- a/filter_functions/basis.py +++ b/filter_functions/basis.py @@ -237,6 +237,13 @@ def _print_checks(self) -> None: for check in checks: print(check, ':\t', getattr(self, check)) + def _invalidate_cached_properties(self): + for attr in {'isherm', 'isnorm', 'isorthogonal', 'istraceless', 'iscomplete'}: + try: + delattr(self, attr) + except AttributeError: + pass + @cached_property def isherm(self) -> bool: """Returns True if all basis elements are hermitian.""" @@ -248,22 +255,25 @@ def isnorm(self) -> bool: return self.normalize(copy=True) == self @cached_property - def isorthonorm(self) -> bool: - """Returns True if basis is orthonormal.""" - # All the basis is orthonormal iff the matrix consisting of all - # d**2 elements written as d**2-dimensional column vectors is - # unitary. + def isorthogonal(self) -> bool: + """Returns True if all basis elements are mutually orthogonal.""" if self.ndim == 2 or len(self) == 1: - # Only one basis element return True - else: - # Size of the result after multiplication - dim = self.shape[0] - U = self.reshape((dim, -1)) - actual = U.conj() @ U.T - target = np.identity(dim) - atol = self._eps*(self.d**2)**3 - return np.allclose(actual.view(np.ndarray), target, atol=atol, rtol=self._rtol) + + # The basis is orthogonal iff the matrix consisting of all d**2 + # elements written as d**2-dimensional column vectors is + # orthogonal. + dim = self.shape[0] + U = self.reshape((dim, -1)) + actual = U.conj() @ U.T + atol = self._eps*(self.d**2)**3 + mask = np.identity(dim, dtype=bool) + return np.allclose(actual[..., ~mask].view(np.ndarray), 0, atol=atol, rtol=self._rtol) + + @property + def isorthonorm(self) -> bool: + """Returns True if basis is orthonormal.""" + return self.isorthogonal and self.isnorm @cached_property def istraceless(self) -> bool: @@ -366,6 +376,7 @@ def normalize(self, copy: bool = False) -> Union[None, 'Basis']: return normalize(self) self /= _norm(self) + self._invalidate_cached_properties() def tidyup(self, eps_scale: Optional[float] = None) -> None: """Wraps util.remove_float_errors.""" @@ -377,6 +388,8 @@ def tidyup(self, eps_scale: Optional[float] = None) -> None: self.real[np.abs(self.real) <= atol] = 0 self.imag[np.abs(self.imag) <= atol] = 0 + self._invalidate_cached_properties() + @classmethod def pauli(cls, n: int) -> 'Basis': r""" @@ -544,8 +557,8 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str]) if not elems.isherm: warn("(Some) elems not hermitian! The resulting basis also won't be.") - if not elems.isorthonorm: - raise ValueError("The basis elements are not orthonormal!") + if not elems.isorthogonal: + raise ValueError("The basis elements are not orthogonal!") if traceless is None: traceless = elems.istraceless @@ -631,7 +644,7 @@ def normalize(b: Basis) -> Basis: Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 """ - return (b/_norm(b)).squeeze().view(Basis) + return (b/_norm(b)).squeeze().reshape(b.shape).view(Basis) def expand(M: Union[np.ndarray, Basis], basis: Union[np.ndarray, Basis], diff --git a/tests/test_basis.py b/tests/test_basis.py index f26796b..a265d79 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -24,7 +24,6 @@ from copy import copy from itertools import product -import filter_functions as ff import numpy as np import pytest from opt_einsum import contract @@ -109,6 +108,7 @@ def test_basis_properties(self): pauli_basis = ff.Basis.pauli(n) from_partial_basis = ff.Basis.from_partial(testutil.rand_herm(d), traceless=False) custom_basis = ff.Basis(testutil.rand_herm_traceless(d)) + custom_basis /= np.linalg.norm(custom_basis, ord='fro', axis=(-1, -2)) btypes = ('Pauli', 'GGM', 'From partial', 'Custom') bases = (pauli_basis, ggm_basis, from_partial_basis, custom_basis) @@ -127,6 +127,8 @@ def test_basis_properties(self): self.assertTrue(base[rng.integers(0, len(base))] in base) # Check if all elements of each basis are orthonormal and hermitian self.assertArrayEqual(base.T, base.view(np.ndarray).swapaxes(-1, -2)) + self.assertTrue(base.isnorm) + self.assertTrue(base.isorthogonal) self.assertTrue(base.isorthonorm) self.assertTrue(base.isherm) # Check if basis spans the whole space and all elems are traceless @@ -137,6 +139,9 @@ def test_basis_properties(self): if not btype == 'Custom': self.assertTrue(base.iscomplete) + else: + self.assertFalse(base.iscomplete) + # Check sparse representation self.assertArrayEqual(base.sparse.todense(), base) # Test sparse cache @@ -154,9 +159,17 @@ def test_basis_properties(self): base._print_checks() - # single element always considered orthonormal - orthonorm = rng.normal(size=(d, d)) - self.assertTrue(orthonorm.view(ff.Basis).isorthonorm) + # single element always considered orthogonal + orthogonal = rng.normal(size=(d, d)) + self.assertTrue(orthogonal.view(ff.Basis).isorthogonal) + + orthogonal /= np.linalg.norm(orthogonal, 'fro', axis=(-1, -2)) + self.assertTrue(orthogonal.view(ff.Basis).isorthonorm) + + nonorthogonal = rng.normal(size=(3, d, d)).view(ff.Basis) + self.assertFalse(nonorthogonal.isnorm) + nonorthogonal.normalize() + self.assertTrue(nonorthogonal.isnorm) herm = testutil.rand_herm(d).squeeze() self.assertTrue(herm.view(ff.Basis).isherm) @@ -206,6 +219,22 @@ def test_basis_expansion_and_normalization(self): r = ff.basis.ggm_expand(rng.standard_normal((3, 3)), hermitian=False) self.assertTrue(r.dtype == 'complex128') + # test the method + pauli = ff.Basis.pauli(1) + ggm = ff.Basis.ggm(2) + + M = testutil.rand_herm(2, 3) + self.assertArrayAlmostEqual(pauli.expand(M, hermitian=True, traceless=False, tidyup=True), + ggm.expand(M, hermitian=True, traceless=False, tidyup=True)) + + M = testutil.rand_herm_traceless(2, 3) + self.assertArrayAlmostEqual(pauli.expand(M, hermitian=True, traceless=True, tidyup=True), + ggm.expand(M, hermitian=True, traceless=True, tidyup=True)) + + M = testutil.rand_unit(2, 3) + self.assertArrayAlmostEqual(pauli.expand(M, hermitian=False, traceless=False, tidyup=True), + ggm.expand(M, hermitian=False, traceless=False, tidyup=True)) + for _ in range(10): d = rng.integers(2, 16) ggm_basis = ff.Basis.ggm(d)