Skip to content

Commit

Permalink
Streamline properties
Browse files Browse the repository at this point in the history
  • Loading branch information
thangleiter committed Nov 7, 2024
1 parent a52b20b commit 22107eb
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
47 changes: 30 additions & 17 deletions filter_functions/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ def _print_checks(self) -> None:
for check in checks:
print(check, ':\t', getattr(self, check))

def _invalidate_cached_properties(self):
for attr in {'isherm', 'isnorm', 'isorthogonal', 'istraceless', 'iscomplete'}:
try:
delattr(self, attr)
except AttributeError:
pass

@cached_property
def isherm(self) -> bool:
"""Returns True if all basis elements are hermitian."""
Expand All @@ -248,22 +255,25 @@ def isnorm(self) -> bool:
return self.normalize(copy=True) == self

@cached_property
def isorthonorm(self) -> bool:
"""Returns True if basis is orthonormal."""
# All the basis is orthonormal iff the matrix consisting of all
# d**2 elements written as d**2-dimensional column vectors is
# unitary.
def isorthogonal(self) -> bool:
"""Returns True if all basis elements are mutually orthogonal."""
if self.ndim == 2 or len(self) == 1:
# Only one basis element
return True
else:
# Size of the result after multiplication
dim = self.shape[0]
U = self.reshape((dim, -1))
actual = U.conj() @ U.T
target = np.identity(dim)
atol = self._eps*(self.d**2)**3
return np.allclose(actual.view(np.ndarray), target, atol=atol, rtol=self._rtol)

# The basis is orthogonal iff the matrix consisting of all d**2
# elements written as d**2-dimensional column vectors is
# orthogonal.
dim = self.shape[0]
U = self.reshape((dim, -1))
actual = U.conj() @ U.T
atol = self._eps*(self.d**2)**3
mask = np.identity(dim, dtype=bool)
return np.allclose(actual[..., ~mask].view(np.ndarray), 0, atol=atol, rtol=self._rtol)

@property
def isorthonorm(self) -> bool:
"""Returns True if basis is orthonormal."""
return self.isorthogonal and self.isnorm

@cached_property
def istraceless(self) -> bool:
Expand Down Expand Up @@ -366,6 +376,7 @@ def normalize(self, copy: bool = False) -> Union[None, 'Basis']:
return normalize(self)

self /= _norm(self)
self._invalidate_cached_properties()

def tidyup(self, eps_scale: Optional[float] = None) -> None:
"""Wraps util.remove_float_errors."""
Expand All @@ -377,6 +388,8 @@ def tidyup(self, eps_scale: Optional[float] = None) -> None:
self.real[np.abs(self.real) <= atol] = 0
self.imag[np.abs(self.imag) <= atol] = 0

self._invalidate_cached_properties()

@classmethod
def pauli(cls, n: int) -> 'Basis':
r"""
Expand Down Expand Up @@ -544,8 +557,8 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
if not elems.isherm:
warn("(Some) elems not hermitian! The resulting basis also won't be.")

if not elems.isorthonorm:
raise ValueError("The basis elements are not orthonormal!")
if not elems.isorthogonal:
raise ValueError("The basis elements are not orthogonal!")

if traceless is None:
traceless = elems.istraceless
Expand Down Expand Up @@ -631,7 +644,7 @@ def normalize(b: Basis) -> Basis:
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
"""
return (b/_norm(b)).squeeze().view(Basis)
return (b/_norm(b)).squeeze().reshape(b.shape).view(Basis)


def expand(M: Union[np.ndarray, Basis], basis: Union[np.ndarray, Basis],
Expand Down
37 changes: 33 additions & 4 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from copy import copy
from itertools import product

import filter_functions as ff
import numpy as np
import pytest
from opt_einsum import contract
Expand Down Expand Up @@ -109,6 +108,7 @@ def test_basis_properties(self):
pauli_basis = ff.Basis.pauli(n)
from_partial_basis = ff.Basis.from_partial(testutil.rand_herm(d), traceless=False)
custom_basis = ff.Basis(testutil.rand_herm_traceless(d))
custom_basis /= np.linalg.norm(custom_basis, ord='fro', axis=(-1, -2))

btypes = ('Pauli', 'GGM', 'From partial', 'Custom')
bases = (pauli_basis, ggm_basis, from_partial_basis, custom_basis)
Expand All @@ -127,6 +127,8 @@ def test_basis_properties(self):
self.assertTrue(base[rng.integers(0, len(base))] in base)
# Check if all elements of each basis are orthonormal and hermitian
self.assertArrayEqual(base.T, base.view(np.ndarray).swapaxes(-1, -2))
self.assertTrue(base.isnorm)
self.assertTrue(base.isorthogonal)
self.assertTrue(base.isorthonorm)
self.assertTrue(base.isherm)
# Check if basis spans the whole space and all elems are traceless
Expand All @@ -137,6 +139,9 @@ def test_basis_properties(self):

if not btype == 'Custom':
self.assertTrue(base.iscomplete)
else:
self.assertFalse(base.iscomplete)

# Check sparse representation
self.assertArrayEqual(base.sparse.todense(), base)
# Test sparse cache
Expand All @@ -154,9 +159,17 @@ def test_basis_properties(self):

base._print_checks()

# single element always considered orthonormal
orthonorm = rng.normal(size=(d, d))
self.assertTrue(orthonorm.view(ff.Basis).isorthonorm)
# single element always considered orthogonal
orthogonal = rng.normal(size=(d, d))
self.assertTrue(orthogonal.view(ff.Basis).isorthogonal)

orthogonal /= np.linalg.norm(orthogonal, 'fro', axis=(-1, -2))
self.assertTrue(orthogonal.view(ff.Basis).isorthonorm)

nonorthogonal = rng.normal(size=(3, d, d)).view(ff.Basis)
self.assertFalse(nonorthogonal.isnorm)
nonorthogonal.normalize()
self.assertTrue(nonorthogonal.isnorm)

herm = testutil.rand_herm(d).squeeze()
self.assertTrue(herm.view(ff.Basis).isherm)
Expand Down Expand Up @@ -206,6 +219,22 @@ def test_basis_expansion_and_normalization(self):
r = ff.basis.ggm_expand(rng.standard_normal((3, 3)), hermitian=False)
self.assertTrue(r.dtype == 'complex128')

# test the method
pauli = ff.Basis.pauli(1)
ggm = ff.Basis.ggm(2)

M = testutil.rand_herm(2, 3)
self.assertArrayAlmostEqual(pauli.expand(M, hermitian=True, traceless=False, tidyup=True),
ggm.expand(M, hermitian=True, traceless=False, tidyup=True))

M = testutil.rand_herm_traceless(2, 3)
self.assertArrayAlmostEqual(pauli.expand(M, hermitian=True, traceless=True, tidyup=True),
ggm.expand(M, hermitian=True, traceless=True, tidyup=True))

M = testutil.rand_unit(2, 3)
self.assertArrayAlmostEqual(pauli.expand(M, hermitian=False, traceless=False, tidyup=True),
ggm.expand(M, hermitian=False, traceless=False, tidyup=True))

for _ in range(10):
d = rng.integers(2, 16)
ggm_basis = ff.Basis.ggm(d)
Expand Down

0 comments on commit 22107eb

Please sign in to comment.