diff --git a/.github/ACKNOWLEDGMENTS.md b/.github/ACKNOWLEDGMENTS.md index 6eec2777..833bfc3f 100644 --- a/.github/ACKNOWLEDGMENTS.md +++ b/.github/ACKNOWLEDGMENTS.md @@ -67,3 +67,5 @@ * [Martin Houde](https://github.com/MHoude2) (Polytechnique Montréal) - 🙃 Minister of amplification * Will McCutcheon (Heriot-Watt University) - 🧅 Gaussian Onion Merchant + +* [Yanic Cardin](https://github.com/yaniccd) (Polytechnique Montréal) - 🦜 Pirate of the permutations diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 4c0b3875..13db6a7f 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -1,19 +1,15 @@ -# Release 0.21.0-dev +# Release 0.22.0-dev ### New features -* Adds the Takagi decomposition [(#363)](https://github.com/XanaduAI/thewalrus/pull/338) - ### Breaking changes ### Improvements +* Further simplifies the implementation of `decompositions.williamson` and corrects its docstring [(#380)](https://github.com/XanaduAI/thewalrus/pull/380). -* Tighten power-trace bound of odd loop Hafnian. [(#362)](https://github.com/XanaduAI/thewalrus/pull/362) - -* Simplifies the internal working of Bloch-Messiah decomposition [(#363)](https://github.com/XanaduAI/thewalrus/pull/338). +* Further simplifies the implementation of `decompositions.blochmessiah` [(#381)](https://github.com/XanaduAI/thewalrus/pull/381). -* Simplifies the internal working of Williamson decomposition [(#366)](https://github.com/XanaduAI/thewalrus/pull/338). ### Bug fixes @@ -21,9 +17,37 @@ ### Contributors -This release contains contributions from (in alphabetical order): +This release contains contributions from (in alphabetical order): + +Nicolas Quesada + +--- + +# Release 0.21.0 + +### New features + +* Adds the Takagi decomposition [(#363)](https://github.com/XanaduAI/thewalrus/pull/363) + +* Adds the Montrealer and Loop Montrealer functions [(#363)](https://github.com/XanaduAI/thewalrus/pull/374). + +### Improvements + +* Tighten power-trace bound of odd loop Hafnian. [(#362)](https://github.com/XanaduAI/thewalrus/pull/362) + +* Simplifies the internal working of Bloch-Messiah decomposition [(#363)](https://github.com/XanaduAI/thewalrus/pull/338). + +* Simplifies the internal working of Williamson decomposition [(#366)](https://github.com/XanaduAI/thewalrus/pull/338). + +* Improves the handling of an edge case in Takagi [(#373)](https://github.com/XanaduAI/thewalrus/pull/373). + +* Adds extra tests for the Takagi decomposition [(#377)](https://github.com/XanaduAI/thewalrus/pull/377) + +### Contributors + +This release contains contributions from (in alphabetical order): -Gregory Morse, Nicolas Quesada +Yanic Cardin, Gregory Morse, Nicolas Quesada --- diff --git a/thewalrus/__init__.py b/thewalrus/__init__.py index 87ec881b..3e0c84c7 100644 --- a/thewalrus/__init__.py +++ b/thewalrus/__init__.py @@ -132,6 +132,12 @@ rec_torontonian, rec_ltorontonian, ) + +from ._montrealer import ( + mtl, + lmtl, +) + from ._version import __version__ @@ -152,6 +158,8 @@ "reduction", "hermite_multidimensional", "grad_hermite_multidimensional", + "mtl", + "lmtl", "version", ] diff --git a/thewalrus/_montrealer.py b/thewalrus/_montrealer.py new file mode 100644 index 00000000..89c331e2 --- /dev/null +++ b/thewalrus/_montrealer.py @@ -0,0 +1,134 @@ +""" +Montrealer Python interface + +* Yanic Cardin and Nicolás Quesada. "Photon-number moments and cumulants of Gaussian states" + `arxiv:12212.06067 (2023) `_ +""" +import numpy as np +import numba +from thewalrus.quantum.conversions import Xmat +from thewalrus.charpoly import powertrace +from ._hafnian import nb_ix +from ._torontonian import tor_input_checks + + +@numba.jit(nopython=True, cache=True) +def dec2bin(num, n): # pragma: no cover + """Helper function to convert an integer into an element of the power-set of ``n`` objects. + + Args: + num (int): label to convert + n (int): number of elements in the set + + Returns: + (array): array containing the labels of the elements to be selected + """ + digits = np.zeros((n), dtype=type(num)) + nn = num + counter = -1 + while nn >= 1: + digits[counter] = nn % 2 + counter -= 1 + nn //= 2 + return np.nonzero(digits)[0] + + +@numba.jit(nopython=True) +def montrealer(Sigma): # pragma: no cover + """Calculates the loop-montrealer of the zero-displacement Gaussian state with the given complex covariance matrix. + + Args: + Sigma (array): adjacency matrix of the Gaussian state + + Returns: + (np.complex128): the montrealer of ``Sigma`` + """ + n = len(Sigma) // 2 + tot_num = 2**n + val = np.complex128(0) + for p in numba.prange(tot_num): + pos = dec2bin(p, n) + lenpos = len(pos) + pos = np.concatenate((pos, n + pos)) + submat = nb_ix(Sigma, pos, pos) + sign = (-1) ** (lenpos + 1) + val += (sign) * powertrace(submat, n + 1)[-1] + return (-1) ** (n + 1) * val / (2 * n) + + +@numba.jit(nopython=True) +def power_loop(Sigma, zeta, n): # pragma: no cover + """Auxiliary function to calculate the product ``np.conj(zeta) @ Sigma^{n-1} @ zeta``. + + Args: + Sigma (array): square complex matrix + zeta (array): complex vector + n (int): sought after power + + Returns: + (np.complex128 or np.float64): the product np.conj(zeta) @ Sigma^{n-1} @ zeta + """ + vec = zeta + for _ in range(n - 1): + vec = Sigma @ vec + return np.conj(zeta) @ vec + + +@numba.jit(nopython=True, cache=True) +def lmontrealer(Sigma, zeta): # pragma: no cover + """Calculates the loop-montrealer of the displaced Gaussian state with the given complex covariance matrix and vector of displacements. + + Args: + Sigma (array): complex Glauber covariance matrix of the Gaussian state + zeta (array): vector of displacements + + Returns: + (np.complex128): the montrealer of ``Sigma`` + """ + n = len(Sigma) // 2 + tot_num = 2**n + val = np.complex128(0) + val_loops = np.complex128(0) + for p in numba.prange(tot_num): + pos = dec2bin(p, n) + lenpos = len(pos) + pos = np.concatenate((pos, n + pos)) + subvec = zeta[pos] + submat = nb_ix(Sigma, pos, pos) + sign = (-1) ** (lenpos + 1) + val_loops += sign * power_loop(submat, subvec, n) + val += sign * powertrace(submat, n + 1)[-1] + return (-1) ** (n + 1) * (val / (2 * n) + val_loops / 2) + + +def lmtl(A, zeta): + """Returns the montrealer of an NxN matrix and an N-length vector. + + Args: + A (array): an NxN array of even dimensions + zeta (array): an N-length vector of even dimensions + + Returns: + np.float64 or np.complex128: the loop montrealer of matrix A, vector zeta + """ + + tor_input_checks(A, zeta) + n = len(A) // 2 + Sigma = Xmat(n) @ A + return lmontrealer(Sigma, zeta) + + +def mtl(A): + """Returns the montrealer of an NxN matrix. + + Args: + A (array): an NxN array of even dimensions. + + Returns: + np.float64 or np.complex128: the montrealer of matrix ``A`` + """ + + tor_input_checks(A) + n = len(A) // 2 + Sigma = Xmat(n) @ A + return montrealer(Sigma) diff --git a/thewalrus/_torontonian.py b/thewalrus/_torontonian.py index 8626c46b..1679a513 100644 --- a/thewalrus/_torontonian.py +++ b/thewalrus/_torontonian.py @@ -20,19 +20,15 @@ from ._hafnian import reduction, find_kept_edges, nb_ix -def tor(A, recursive=True): - """Returns the Torontonian of a matrix. +def tor_input_checks(A, loops=None): + """Checks the correctness of the inputs for the torontonian/montrealer. Args: - A (array): a square array of even dimensions. - recursive: use the faster recursive implementation. - - Returns: - np.float64 or np.complex128: the torontonian of matrix A. + A (array): an NxN array of even dimensions + loops (array): optional argument, an N-length vector of even dimensions """ if not isinstance(A, np.ndarray): raise TypeError("Input matrix must be a NumPy array.") - matshape = A.shape if matshape[0] != matshape[1]: @@ -40,6 +36,25 @@ def tor(A, recursive=True): if matshape[0] % 2 != 0: raise ValueError("matrix dimension must be even") + + if loops is not None: + if not isinstance(loops, np.ndarray): + raise TypeError("Input matrix must be a NumPy array.") + if matshape[0] != len(loops): + raise ValueError("gamma must be a vector matching the dimension of A") + + +def tor(A, recursive=True): + """Returns the Torontonian of a matrix. + + Args: + A (array): a square array of even dimensions. + recursive: use the faster recursive implementation. + + Returns: + np.float64 or np.complex128: the torontonian of matrix A. + """ + tor_input_checks(A) return rec_torontonian(A) if recursive else numba_tor(A) @@ -54,23 +69,7 @@ def ltor(A, gamma, recursive=True): Returns: np.float64 or np.complex128: the loop torontonian of matrix A, vector gamma """ - - if not isinstance(A, np.ndarray): - raise TypeError("Input matrix must be a NumPy array.") - - if not isinstance(gamma, np.ndarray): - raise TypeError("Input matrix must be a NumPy array.") - - matshape = A.shape - - if matshape[0] != matshape[1]: - raise ValueError("Input matrix must be square.") - - if matshape[0] != len(gamma): - raise ValueError("gamma must be a vector matching the dimension of A") - - if matshape[0] % 2 != 0: - raise ValueError("matrix dimension must be even") + tor_input_checks(A, gamma) return rec_ltorontonian(A, gamma) if recursive else numba_ltor(A, gamma) diff --git a/thewalrus/_version.py b/thewalrus/_version.py index 8c079ffc..f839ff5d 100644 --- a/thewalrus/_version.py +++ b/thewalrus/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.21.0-dev" +__version__ = "0.22.0-dev" diff --git a/thewalrus/decompositions.py b/thewalrus/decompositions.py index fdcd49c7..70124d4e 100644 --- a/thewalrus/decompositions.py +++ b/thewalrus/decompositions.py @@ -36,7 +36,7 @@ """ import numpy as np -from scipy.linalg import block_diag, sqrtm, schur +from scipy.linalg import sqrtm, schur, polar from thewalrus.symplectic import sympmat from thewalrus.quantum.gaussian_checks import is_symplectic @@ -54,7 +54,7 @@ def williamson(V, rtol=1e-05, atol=1e-08): Returns: tuple[array,array]: ``(Db, S)`` where ``Db`` is a diagonal matrix - and ``S`` is a symplectic matrix such that :math:`V = S^T Db S` + and ``S`` is a symplectic matrix such that :math:`V = S Db S^T` """ (n, m) = V.shape @@ -74,28 +74,27 @@ def williamson(V, rtol=1e-05, atol=1e-08): if not np.all(vals > 0): raise ValueError("Input matrix is not positive definite") - Mm12 = sqrtm(np.linalg.inv(V)).real - r1 = Mm12 @ omega @ Mm12 - s1, K = schur(r1) - # In what follows a permutation matrix perm1 is constructed so that the Schur matrix has + M12 = np.real_if_close(sqrtm(V)) + Mm12 = np.linalg.inv(M12) + Gamma = Mm12 @ omega @ Mm12 + a, O = schur(Gamma) + # In what follows a permutation matrix perm is constructed so that the Schur matrix has # only positive elements above the diagonal - # Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus a permutation perm2 is used + # Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus the permutation perm is updated # to go to the ordering x_1, ..., x_n, p_1, ... , p_n - perm1 = np.arange(2 * n) + perm = np.arange(2 * n) for i in range(n): - if s1[2 * i, 2 * i + 1] <= 0: - (perm1[2 * i], perm1[2 * i + 1]) = (perm1[2 * i + 1], perm1[2 * i]) + if a[2 * i, 2 * i + 1] <= 0: + (perm[2 * i], perm[2 * i + 1]) = (perm[2 * i + 1], perm[2 * i]) - perm2 = np.array([perm1[2 * i] for i in range(n)] + [perm1[2 * i + 1] for i in range(n)]) + perm = np.array([perm[2 * i] for i in range(n)] + [perm[2 * i + 1] for i in range(n)]) - Ktt = K[:, perm2] - s1t = s1[:, perm1][perm1] - - dd = np.array([1 / s1t[2 * i, 2 * i + 1] for i in range(n)]) - dd = np.concatenate([dd, dd]) - ddsqrt = np.sqrt(dd) - S = Mm12 @ Ktt * ddsqrt - return np.diag(dd), np.linalg.inv(S).T + O = O[:, perm] + phi = np.abs(np.diag(a, k=1)[::2]) + dd = np.concatenate([1 / phi, 1 / phi]) + ddsqrt = 1 / np.sqrt(dd) + S = M12 @ O * ddsqrt + return np.diag(dd), S def symplectic_eigenvals(cov): @@ -107,67 +106,48 @@ def symplectic_eigenvals(cov): Returns: (array): symplectic eigenvalues """ - M = int(len(cov) / 2) + M = len(cov) // 2 Omega = sympmat(M) return np.real_if_close(-1j * np.linalg.eigvals(Omega @ cov))[::2] def blochmessiah(S): - """Returns the Bloch-Messiah decomposition of a symplectic matrix S = uff @ dff @ vff - where uff and vff are orthogonal symplectic matrices and dff is a diagonal matrix + """Returns the Bloch-Messiah decomposition of a symplectic matrix S = O @ D @ Q + where O and Q are orthogonal symplectic matrices and D is a positive-definite diagonal matrix of the form diag(d1,d2,...,dn,d1^-1, d2^-1,...,dn^-1), Args: S (array[float]): 2N x 2N real symplectic matrix Returns: - tuple(array[float], : orthogonal symplectic matrix uff - array[float], : diagonal matrix dff - array[float]) : orthogonal symplectic matrix vff + tuple(array[float], : orthogonal symplectic matrix O + array[float], : diagonal matrix D + array[float]) : orthogonal symplectic matrix Q """ N, _ = S.shape if not is_symplectic(S): raise ValueError("Input matrix is not symplectic.") - - # Changing Basis - R = (1 / np.sqrt(2)) * np.block( - [[np.eye(N // 2), 1j * np.eye(N // 2)], [np.eye(N // 2), -1j * np.eye(N // 2)]] - ) - Sc = R @ S @ np.conjugate(R).T - # Polar Decomposition - u1, d1, v1 = np.linalg.svd(Sc) - Sig = u1 @ np.diag(d1) @ np.conjugate(u1).T - Unitary = u1 @ v1 - # Blocks of Unitary and Hermitian symplectics - alpha = Unitary[0 : N // 2, 0 : N // 2] - beta = Sig[0 : N // 2, N // 2 : N] - # Bloch-Messiah in this Basis - d2, takagibeta = takagi(beta) - sval = np.arcsinh(d2) - uf = block_diag(takagibeta, takagibeta.conj()) - blc = np.conjugate(takagibeta).T @ alpha - vf = block_diag(blc, blc.conj()) - df = np.block( - [ - [np.diag(np.cosh(sval)), np.diag(np.sinh(sval))], - [np.diag(np.sinh(sval)), np.diag(np.cosh(sval))], - ] - ) - # Rotating Back to Original Basis - uff = np.conjugate(R).T @ uf @ R - vff = np.conjugate(R).T @ vf @ R - dff = np.conjugate(R).T @ df @ R - dff = np.real_if_close(dff) - vff = np.real_if_close(vff) - uff = np.real_if_close(uff) - return uff, dff, vff + N = N // 2 + V, P = polar(S, side="left") + A = P[:N, :N] + B = P[:N, N:] + C = P[N:, N:] + M = A - C + 1j * (B + B.T) + Lam, W = takagi(M) + Lam = 0.5 * Lam + O = np.block([[W.real, -W.imag], [W.imag, W.real]]) + Q = O.T @ V + sqrt1pLam2 = np.sqrt(1 + Lam**2) + D = np.diag(np.concatenate([sqrt1pLam2 + Lam, sqrt1pLam2 - Lam])) + return O, D, Q def takagi(A, svd_order=True): r"""Autonne-Takagi decomposition of a complex symmetric (not Hermitian!) matrix. - Note that the input matrix is internally symmetrized. If the input matrix is indeed symmetric this leaves it unchanged. + Note that the input matrix is internally symmetrized by taking its upper triangular part. + If the input matrix is indeed symmetric this leaves it unchanged. See `Carl Caves note. `_ Args: @@ -182,8 +162,8 @@ def takagi(A, svd_order=True): n, m = A.shape if n != m: raise ValueError("The input matrix is not square") - # Here we force symmetrize the matrix - A = 0.5 * (A + A.T) + # Here we build a Symmetric matrix from the top right triangular part + A = np.triu(A) + np.triu(A, k=1).T A = np.real_if_close(A) @@ -193,26 +173,21 @@ def takagi(A, svd_order=True): if np.isrealobj(A): # If the matrix A is real one can be more clever and use its eigendecomposition ls, U = np.linalg.eigh(A) - U = U / np.exp(1j * np.angle(U)[0]) vals = np.abs(ls) # These are the Takagi eigenvalues - phases = -np.ones(vals.shape[0], dtype=np.complex128) - for j, l in enumerate(ls): - if np.allclose(l, 0) or l > 0: - phases[j] = 1 - phases = np.sqrt(phases) - Uc = U @ np.diag(phases) # One needs to readjust the phases - signs = np.sign(Uc.real)[0] - for k, s in enumerate(signs): - if np.allclose(s, 0): - signs[k] = 1 - Uc = np.real_if_close(Uc / signs) - list_vals = [(vals[i], i) for i in range(len(vals))] - # And also rearrange the unitary and values so that they are decreasingly ordered - list_vals.sort(reverse=svd_order) - sorted_ls, permutation = zip(*list_vals) - return np.array(sorted_ls), Uc[:, np.array(permutation)] - - phi = np.angle(A[0, 0]) + signs = (-1) ** (1 + np.heaviside(ls, 1)) + phases = np.sqrt(np.complex128(signs)) + Uc = U * phases # One needs to readjust the phases + # Find the permutation to sort in decreasing order + perm = np.argsort(vals) + # if svd_order reverse it + if svd_order: + perm = perm[::-1] + return vals[perm], Uc[:, perm] + + # Find the element with the largest absolute value + pos = np.unravel_index(np.argmax(np.abs(A)), (n, n)) + # Use it to find whether the input is a global phase times a real matrix + phi = np.angle(A[pos]) Amr = np.real_if_close(np.exp(-1j * phi) * A) if np.isrealobj(Amr): vals, U = takagi(Amr, svd_order=svd_order) @@ -220,10 +195,6 @@ def takagi(A, svd_order=True): u, d, v = np.linalg.svd(A) U = u @ sqrtm((v @ np.conjugate(u)).T) - # The line above could be simplifed to the line below if the product v @ np.conjugate(u) is diagonal - # Which it should be according to Caves http://info.phys.unm.edu/~caves/courses/qinfo-s17/lectures/polarsingularAutonne.pdf - # U = u * np.sqrt(0j + np.diag(np.conjugate(u) @ v)) - # This however breaks test_degenerate if svd_order is False: return d[::-1], U[:, ::-1] return d, U diff --git a/thewalrus/reference.py b/thewalrus/reference.py index 00d7934e..9bd2e041 100644 --- a/thewalrus/reference.py +++ b/thewalrus/reference.py @@ -34,12 +34,14 @@ .. autosummary:: hafnian + montrealer Code details ------------ .. autofunction:: hafnian + montrealer Auxiliary functions ------------------- @@ -49,6 +51,8 @@ partitions spm pmp + rspm + rpmp T Code details @@ -58,7 +62,7 @@ # pylint: disable=too-many-arguments from collections import OrderedDict -from itertools import tee +from itertools import tee, product, permutations, chain from types import GeneratorType MAXSIZE = 1000 @@ -278,3 +282,121 @@ def hafnian(M, loop=False): tot_sum = tot_sum + result return tot_sum + + +def mapper(x, objects): + """Helper function to turn a permutation and bistring into an element of rpmp. + + Args: + x (tuple): tuple containing a permutation and a bistring + objects (list): list objects to permute + + Returns: + tuple: permuted objects + """ + (perm, bit) = x + m = len(bit) + Blist = [list(range(m)), list(range(m, 2 * m))] + for i, j in enumerate(bit): + if int(j): + (Blist[0][i], Blist[1][i]) = (Blist[1][i], Blist[0][i]) + Blist = [Blist[0][i] for i in tuple((0,) + perm)] + [Blist[1][i] for i in tuple((0,) + perm)] + dico_list = {j: i + 1 for i, j in enumerate(Blist)} + new_mapping_list = { + objects[dico_list[i] - 1]: objects[dico_list[j] - 1] + for i, j in zip(list(range(0, m - 1)) + [m], list(range(m + 1, 2 * m)) + [m - 1]) + } + return tuple(new_mapping_list.items()) + + +def bitstrings(n): + """Returns the bistrings from 0 to n/2 + + Args: + n (int): Twice the highest bitstring value. + + Returns: + (iterator): An iterable of all bistrings. + """ + for binary in map("".join, product("01", repeat=n - 1)): + yield "0" + binary + + +def rpmp(s): + """Generates the restricted set of perfect matchings matching permutations. + + Args: + s (tuple): tuple of labels to be used + + Returns: + generator: the set of restricted perfect matching permutations of the tuple ``s`` + """ + m = len(s) // 2 + + def local_mapper(x): + """Helper function to define a local mapper based on the symbols s + Args: + x (iterable): object to be mapped + """ + return mapper(x, s) + + for i in product(permutations(range(1, m)), bitstrings(m)): + yield local_mapper(i) + + +def splitter(elem): + """Takes an element from the restricted perfect matching permutations (rpmp) and returns all the associated elements in the restricted single pair matchings (rspm) + + Args: + elem (tuple): tuple representing an element of rpmp + + Returns: + (iterator): all the associated elements in rspm + """ + num_elem = len(elem) + net = [elem] + for i in range(num_elem): + left = (elem[j] for j in range(i)) + middle_left = ((elem[i][0], elem[i][0]),) + middle_right = ((elem[i][1], elem[i][1]),) + right = (elem[j] for j in range(i + 1, num_elem)) + net.append(tuple(middle_right) + tuple(right) + tuple(left) + tuple(middle_left)) + for i in net: + yield i + + +def rspm(s): + """Generates the restricted set of single-pair matchings. + + Args: + s (tuple): tuple of labels to be used + + Returns: + generator: the set of restricted perfect matching permutations of the tuple s + """ + gen = rpmp(s) + return chain(*(splitter(i) for i in gen)) + + +def mtl(A, loop=False): + """Returns the Montrealer of an NxN matrix and an N-length vector. + + Args: + A (array): an NxN array of even dimensions. Can be symbolic. + loop (boolean): if set to ``True``, the loop montrealer is returned + + Returns: + np.float64, np.complex128 or sympy.core.add.Add: the Montrealer of matrix ``A``. + """ + n, _ = A.shape + net_sum = 0 + + perm = rspm(range(n)) if loop else rpmp(range(n)) + for s in perm: + net_prod = 1 + for a in s: + net_prod *= A[a[0], a[1]] + + net_sum += net_prod + + return net_sum diff --git a/thewalrus/tests/test_decompositions.py b/thewalrus/tests/test_decompositions.py index e6684df1..a2dd9e11 100644 --- a/thewalrus/tests/test_decompositions.py +++ b/thewalrus/tests/test_decompositions.py @@ -274,23 +274,30 @@ def test_takagi(n, datatype, svd_order): assert np.all(np.diff(r) >= 0) +# pylint: disable=too-many-arguments @pytest.mark.parametrize("n", [5, 10, 50]) @pytest.mark.parametrize("datatype", [np.complex128, np.float64]) @pytest.mark.parametrize("svd_order", [True, False]) @pytest.mark.parametrize("half_rank", [0, 1]) @pytest.mark.parametrize("phase", [0, 1]) -def test_degenerate(n, datatype, svd_order, half_rank, phase): +@pytest.mark.parametrize("null_space", [0, 5, 10]) +@pytest.mark.parametrize("offset", [0, 0.5]) +def test_degenerate(n, datatype, svd_order, half_rank, phase, null_space, offset): """Tests Takagi produces the correct result for very degenerate cases""" nhalf = n // 2 - diags = [half_rank * np.random.rand()] * nhalf + [np.random.rand()] * (n - nhalf) + diags = ( + [half_rank * np.random.rand()] * nhalf + + [np.random.rand() - offset] * (n - nhalf) + + [0] * null_space + ) if datatype is np.complex128: - U = haar_measure(n) + U = haar_measure(n + null_space) if datatype is np.float64: - U = np.exp(1j * phase) * haar_measure(n, real=True) + U = np.exp(1j * phase) * haar_measure(n + null_space, real=True) A = U @ np.diag(diags) @ U.T r, U = takagi(A, svd_order=svd_order) assert np.allclose(A, U @ np.diag(r) @ U.T) - assert np.allclose(U @ U.T.conj(), np.eye(n)) + assert np.allclose(U @ U.T.conj(), np.eye(n + null_space)) assert np.all(r >= 0) if svd_order is True: assert np.all(np.diff(r) <= 0) @@ -394,3 +401,43 @@ def test_real_degenerate(): rl, U = takagi(mat) assert np.allclose(U @ U.conj().T, np.eye(len(mat))) assert np.allclose(U @ np.diag(rl) @ U.T, mat) + + +@pytest.mark.parametrize("n", [5, 10, 50]) +@pytest.mark.parametrize("datatype", [np.complex128, np.float64]) +@pytest.mark.parametrize("svd_order", [True, False]) +def test_autonne_takagi(n, datatype, svd_order): + """Checks the correctness of the Autonne decomposition function""" + if datatype is np.complex128: + A = np.random.rand(n, n) + 1j * np.random.rand(n, n) + if datatype is np.float64: + A = np.random.rand(n, n) + A += A.T + r, U = takagi(A, svd_order=svd_order) + assert np.allclose(A, U @ np.diag(r) @ U.T) + assert np.all(r >= 0) + if svd_order is True: + assert np.all(np.diff(r) <= 0) + else: + assert np.all(np.diff(r) >= 0) + + +@pytest.mark.parametrize("size", [10, 20, 100]) +def test_flat_phase(size): + """Test that the correct decomposition is obtained even if the first entry is 0""" + A = np.random.rand(size, size) + 1j * np.random.rand(size, size) + A += A.T + A[0, 0] = 0 + l, u = takagi(A) + assert np.allclose(A, u * l @ u.T) + + +def test_real_input_edge(): + """Adapted from https://math.stackexchange.com/questions/4418925/why-does-this-algorithm-for-the-takagi-factorization-fail-here""" + rng = np.random.default_rng(0) # Important for reproducibility + A = (rng.random((100, 100)) - 0.5) * 114 + A = A * A.T # make A symmetric + l, u = takagi(A) + # Now, reconstruct A, see + Ar = u * l @ u.T + assert np.allclose(A, Ar) diff --git a/thewalrus/tests/test_montrealer.py b/thewalrus/tests/test_montrealer.py new file mode 100644 index 00000000..f511d0fc --- /dev/null +++ b/thewalrus/tests/test_montrealer.py @@ -0,0 +1,181 @@ +# Copyright 2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain adj copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Montrealer tests +Yanic Cardin and Nicolás Quesada. "Photon-number moments and cumulants of Gaussian states" +`arxiv:12212.06067 (2023) `_ +""" + +import pytest +import numpy as np +from thewalrus import mtl, lmtl +from thewalrus.reference import mapper +from thewalrus.quantum import Qmat, Xmat +from thewalrus.reference import rspm, rpmp, mtl as mtl_symb +from thewalrus.random import random_covariance +from scipy.special import factorial2 +from scipy.stats import unitary_group + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_montrealer_all_ones(n): + """Test that the Montrealer of a matrix of ones gives (2n-2)!!""" + adj = np.ones([2 * n, 2 * n]) + mtl_val = mtl(adj) + mtl_expect = factorial2(2 * n - 2) + assert np.allclose(mtl_val, mtl_expect) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_loop_montrealer_all_ones(n): + """Test that the loop Montrealer of a matrix of ones gives (n+1)(2n-2)!!""" + adj = np.ones([2 * n, 2 * n]) + lmtl_val = lmtl(adj, zeta=np.diag(adj)) + lmtl_expect = (n + 1) * factorial2(2 * n - 2) + assert np.allclose(lmtl_val, lmtl_expect) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_size_of_rpmp(n): + """rpmp(2n) should have (2n-2)!! elements""" + terms_rpmp = len(list(rpmp(range(2 * n)))) + terms_theo = factorial2(2 * n - 2) + assert terms_rpmp == terms_theo + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_size_of_rspm(n): + """rspm(2n) should have (n+1)(2n-2)!! elements""" + terms_rspm = sum(1 for _ in rspm(range(2 * n))) + terms_theo = (n + 1) * factorial2(2 * n - 2) + assert terms_rspm == terms_theo + + +@pytest.mark.parametrize("n", range(2, 8)) +def test_rpmp_alternating_walk(n): + """The rpmp must form a Y-alternating walk without loops""" + test = True + for perfect in rpmp(range(1, 2 * n + 1)): + last = perfect[0][1] # starting point + reduced_last = last - n if last > n else last + # different mode in every tuple + if reduced_last == 1: + test = False + + for i in perfect[1:]: + reduced = i[0] - n if i[0] > n else i[0], i[1] - n if i[1] > n else i[1] + # different mode in every tuple + if reduced[0] == reduced[1]: + test = False + # consecutive tuple contain the same mode + if reduced_last not in reduced: + test = False + + last = i[0] if reduced[1] == reduced_last else i[1] + reduced_last = last - n if last > n else last + + # last mode most coincide with the first one + if reduced_last != 1: + test = False + + assert test + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_mtl_functions_agree(n): + """Make sure both mtl functions agree with one another""" + V = random_covariance(n) + Aad = Xmat(n) @ (Qmat(V) - np.identity(2 * n)) + assert np.allclose(mtl_symb(Aad), mtl(Aad)) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_lmtl_functions_agree(n): + """Make sure both lmtl functions agree with one another""" + V = random_covariance(n) + Aad = Xmat(n) @ (Qmat(V) - np.identity(2 * n)) + zeta = np.diag(Aad).conj() + assert np.allclose(lmtl(Aad, zeta), mtl_symb(Aad, loop=True)) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_mtl_lmtl_agree(n): + """Make sure mtl and lmtl give the same result if zeta = 0""" + V = random_covariance(n) + Aad = Xmat(n) @ (Qmat(V) - np.identity(2 * n)) + zeta = np.zeros(2 * n, dtype=np.complex128) + assert np.allclose(lmtl(Aad, zeta), lmtl(Aad, zeta)) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_mtl_lmtl_reference_agree(n): + """Make sure mtl and lmtl from .reference give the same result if zeta = 0""" + V = random_covariance(n) + Aad = Xmat(n) @ (Qmat(V) - np.identity(2 * n)) + zeta = np.zeros(2 * n, dtype=np.complex128) + np.fill_diagonal(Aad, zeta) + assert np.allclose(mtl_symb(Aad, loop=True), mtl_symb(Aad)) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_mtl_permutation(n): + """Make sure the mtl is invariant under permutation + cf. Eq. 44 of `arxiv:12212.06067 (2023) `_""" + V = random_covariance(n) + Aad = Xmat(n) @ (Qmat(V) - np.identity(2 * n)) + perm = np.random.permutation(n) + perm = np.concatenate((perm, [i + n for i in perm])) + assert np.allclose(mtl(Aad), mtl(Aad[perm][:, perm])) + + +@pytest.mark.parametrize("n", range(2, 5)) +def test_mtl_associated_adjacency(n): + """Make sure the mtl of a matrix in which each block is block diaognal is zero. + cf. Eq. 45 of `arxiv:12212.06067 (2023) `_""" + u_zero = np.zeros((n, n), dtype=np.complex128) + + u_n1 = unitary_group.rvs(n) + u_n2 = unitary_group.rvs(n) + u_n = np.block([[u_n1, u_zero], [u_zero, u_n2]]) + u_n = u_n + u_n.conj().T + + u_m1 = unitary_group.rvs(n) + u_m2 = unitary_group.rvs(n) + u_m = np.block([[u_m1, u_zero], [u_zero, u_m2]]) + u_m_r = u_m + u_m.T + + u_m3 = unitary_group.rvs(n) + u_m4 = unitary_group.rvs(n) + u_m = np.block([[u_m3, u_zero], [u_zero, u_m4]]) + u_m_l = u_m + u_m.T + + adj = np.block([[u_m_r, u_n], [u_n.T, u_m_l]]) + + assert np.allclose(mtl(adj), 0) + + +@pytest.mark.parametrize("n", range(1, 8)) +def test_mtl_diagonal_trace(n): + """Make sure the mtl of A times a diagonal matrix gives the product of the norms of the diagonal matrix times the mtl of A + cf. Eq. 41 of `arxiv:12212.06067 (2023) `_""" + gamma = np.random.uniform(-1, 1, n) + 1.0j * np.random.uniform(-1, 1, n) + product = np.prod([abs(i) ** 2 for i in gamma]) + gamma = np.diag(np.concatenate((gamma, gamma.conj()))) + V = random_covariance(n) + Aad = Xmat(n) @ (Qmat(V) - np.identity(2 * n)) + assert np.allclose(mtl(gamma @ Aad @ gamma), product * mtl(Aad)) + + +def test_mapper_hard_coded(): + """Tests the the mapper function for a particular hardcoded value""" + assert mapper(((1, 2, 3), "0000"), (0, 1, 2, 3, 4, 5, 6, 7)) == ((0, 5), (1, 6), (2, 7), (4, 3))