Skip to content

Commit 22107eb

Browse files
committed
Streamline properties
1 parent a52b20b commit 22107eb

File tree

2 files changed

+63
-21
lines changed

2 files changed

+63
-21
lines changed

filter_functions/basis.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,13 @@ def _print_checks(self) -> None:
237237
for check in checks:
238238
print(check, ':\t', getattr(self, check))
239239

240+
def _invalidate_cached_properties(self):
241+
for attr in {'isherm', 'isnorm', 'isorthogonal', 'istraceless', 'iscomplete'}:
242+
try:
243+
delattr(self, attr)
244+
except AttributeError:
245+
pass
246+
240247
@cached_property
241248
def isherm(self) -> bool:
242249
"""Returns True if all basis elements are hermitian."""
@@ -248,22 +255,25 @@ def isnorm(self) -> bool:
248255
return self.normalize(copy=True) == self
249256

250257
@cached_property
251-
def isorthonorm(self) -> bool:
252-
"""Returns True if basis is orthonormal."""
253-
# All the basis is orthonormal iff the matrix consisting of all
254-
# d**2 elements written as d**2-dimensional column vectors is
255-
# unitary.
258+
def isorthogonal(self) -> bool:
259+
"""Returns True if all basis elements are mutually orthogonal."""
256260
if self.ndim == 2 or len(self) == 1:
257-
# Only one basis element
258261
return True
259-
else:
260-
# Size of the result after multiplication
261-
dim = self.shape[0]
262-
U = self.reshape((dim, -1))
263-
actual = U.conj() @ U.T
264-
target = np.identity(dim)
265-
atol = self._eps*(self.d**2)**3
266-
return np.allclose(actual.view(np.ndarray), target, atol=atol, rtol=self._rtol)
262+
263+
# The basis is orthogonal iff the matrix consisting of all d**2
264+
# elements written as d**2-dimensional column vectors is
265+
# orthogonal.
266+
dim = self.shape[0]
267+
U = self.reshape((dim, -1))
268+
actual = U.conj() @ U.T
269+
atol = self._eps*(self.d**2)**3
270+
mask = np.identity(dim, dtype=bool)
271+
return np.allclose(actual[..., ~mask].view(np.ndarray), 0, atol=atol, rtol=self._rtol)
272+
273+
@property
274+
def isorthonorm(self) -> bool:
275+
"""Returns True if basis is orthonormal."""
276+
return self.isorthogonal and self.isnorm
267277

268278
@cached_property
269279
def istraceless(self) -> bool:
@@ -366,6 +376,7 @@ def normalize(self, copy: bool = False) -> Union[None, 'Basis']:
366376
return normalize(self)
367377

368378
self /= _norm(self)
379+
self._invalidate_cached_properties()
369380

370381
def tidyup(self, eps_scale: Optional[float] = None) -> None:
371382
"""Wraps util.remove_float_errors."""
@@ -377,6 +388,8 @@ def tidyup(self, eps_scale: Optional[float] = None) -> None:
377388
self.real[np.abs(self.real) <= atol] = 0
378389
self.imag[np.abs(self.imag) <= atol] = 0
379390

391+
self._invalidate_cached_properties()
392+
380393
@classmethod
381394
def pauli(cls, n: int) -> 'Basis':
382395
r"""
@@ -544,8 +557,8 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
544557
if not elems.isherm:
545558
warn("(Some) elems not hermitian! The resulting basis also won't be.")
546559

547-
if not elems.isorthonorm:
548-
raise ValueError("The basis elements are not orthonormal!")
560+
if not elems.isorthogonal:
561+
raise ValueError("The basis elements are not orthogonal!")
549562

550563
if traceless is None:
551564
traceless = elems.istraceless
@@ -631,7 +644,7 @@ def normalize(b: Basis) -> Basis:
631644
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
632645
633646
"""
634-
return (b/_norm(b)).squeeze().view(Basis)
647+
return (b/_norm(b)).squeeze().reshape(b.shape).view(Basis)
635648

636649

637650
def expand(M: Union[np.ndarray, Basis], basis: Union[np.ndarray, Basis],

tests/test_basis.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from copy import copy
2525
from itertools import product
2626

27-
import filter_functions as ff
2827
import numpy as np
2928
import pytest
3029
from opt_einsum import contract
@@ -109,6 +108,7 @@ def test_basis_properties(self):
109108
pauli_basis = ff.Basis.pauli(n)
110109
from_partial_basis = ff.Basis.from_partial(testutil.rand_herm(d), traceless=False)
111110
custom_basis = ff.Basis(testutil.rand_herm_traceless(d))
111+
custom_basis /= np.linalg.norm(custom_basis, ord='fro', axis=(-1, -2))
112112

113113
btypes = ('Pauli', 'GGM', 'From partial', 'Custom')
114114
bases = (pauli_basis, ggm_basis, from_partial_basis, custom_basis)
@@ -127,6 +127,8 @@ def test_basis_properties(self):
127127
self.assertTrue(base[rng.integers(0, len(base))] in base)
128128
# Check if all elements of each basis are orthonormal and hermitian
129129
self.assertArrayEqual(base.T, base.view(np.ndarray).swapaxes(-1, -2))
130+
self.assertTrue(base.isnorm)
131+
self.assertTrue(base.isorthogonal)
130132
self.assertTrue(base.isorthonorm)
131133
self.assertTrue(base.isherm)
132134
# Check if basis spans the whole space and all elems are traceless
@@ -137,6 +139,9 @@ def test_basis_properties(self):
137139

138140
if not btype == 'Custom':
139141
self.assertTrue(base.iscomplete)
142+
else:
143+
self.assertFalse(base.iscomplete)
144+
140145
# Check sparse representation
141146
self.assertArrayEqual(base.sparse.todense(), base)
142147
# Test sparse cache
@@ -154,9 +159,17 @@ def test_basis_properties(self):
154159

155160
base._print_checks()
156161

157-
# single element always considered orthonormal
158-
orthonorm = rng.normal(size=(d, d))
159-
self.assertTrue(orthonorm.view(ff.Basis).isorthonorm)
162+
# single element always considered orthogonal
163+
orthogonal = rng.normal(size=(d, d))
164+
self.assertTrue(orthogonal.view(ff.Basis).isorthogonal)
165+
166+
orthogonal /= np.linalg.norm(orthogonal, 'fro', axis=(-1, -2))
167+
self.assertTrue(orthogonal.view(ff.Basis).isorthonorm)
168+
169+
nonorthogonal = rng.normal(size=(3, d, d)).view(ff.Basis)
170+
self.assertFalse(nonorthogonal.isnorm)
171+
nonorthogonal.normalize()
172+
self.assertTrue(nonorthogonal.isnorm)
160173

161174
herm = testutil.rand_herm(d).squeeze()
162175
self.assertTrue(herm.view(ff.Basis).isherm)
@@ -206,6 +219,22 @@ def test_basis_expansion_and_normalization(self):
206219
r = ff.basis.ggm_expand(rng.standard_normal((3, 3)), hermitian=False)
207220
self.assertTrue(r.dtype == 'complex128')
208221

222+
# test the method
223+
pauli = ff.Basis.pauli(1)
224+
ggm = ff.Basis.ggm(2)
225+
226+
M = testutil.rand_herm(2, 3)
227+
self.assertArrayAlmostEqual(pauli.expand(M, hermitian=True, traceless=False, tidyup=True),
228+
ggm.expand(M, hermitian=True, traceless=False, tidyup=True))
229+
230+
M = testutil.rand_herm_traceless(2, 3)
231+
self.assertArrayAlmostEqual(pauli.expand(M, hermitian=True, traceless=True, tidyup=True),
232+
ggm.expand(M, hermitian=True, traceless=True, tidyup=True))
233+
234+
M = testutil.rand_unit(2, 3)
235+
self.assertArrayAlmostEqual(pauli.expand(M, hermitian=False, traceless=False, tidyup=True),
236+
ggm.expand(M, hermitian=False, traceless=False, tidyup=True))
237+
209238
for _ in range(10):
210239
d = rng.integers(2, 16)
211240
ggm_basis = ff.Basis.ggm(d)

0 commit comments

Comments
 (0)