diff --git a/README.md b/README.md index 60befbf..965d8c7 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ # qujax Represent a (parameterised) quantum circuit as a pure [JAX](https://github.com/google/jax) function that -takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes of the quantum state and can then be used -downstream for exact expectations, gradients or sampling. +takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes +of the quantum state and can then be used downstream for exact expectations, gradients or sampling. -A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support for GPUs/TPUs. +qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix and allows for mixed states and generic Kraus operators. + +A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support +for GPUs/TPUs. Some useful links: - [Documentation](https://cqcl.github.io/qujax/api/) diff --git a/docs/conf.py b/docs/conf.py index c58ed0e..00bb712 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,7 +17,7 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon'] +extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon', 'sphinx.ext.mathjax'] templates_path = ['_templates'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] @@ -38,3 +38,5 @@ 'Callable[[Optional[ndarray]], ndarray]]' } + +latex_engine = 'pdflatex' diff --git a/docs/densitytensor.rst b/docs/densitytensor.rst new file mode 100644 index 0000000..8dbce28 --- /dev/null +++ b/docs/densitytensor.rst @@ -0,0 +1,14 @@ +densitytensor +======================= + +.. toctree:: + + densitytensor/kraus + densitytensor/get_params_to_densitytensor_func + densitytensor/partial_trace + densitytensor/get_densitytensor_to_expectation_func + densitytensor/get_densitytensor_to_sampled_expectation_func + densitytensor/densitytensor_to_measurement_probabilities + densitytensor/densitytensor_to_measured_densitytensor + densitytensor/statetensor_to_densitytensor + diff --git a/docs/densitytensor/densitytensor_to_measured_densitytensor.rst b/docs/densitytensor/densitytensor_to_measured_densitytensor.rst new file mode 100644 index 0000000..7c52298 --- /dev/null +++ b/docs/densitytensor/densitytensor_to_measured_densitytensor.rst @@ -0,0 +1,5 @@ +densitytensor_to_measured_densitytensor +============================================== + +.. autofunction:: qujax.densitytensor_to_measured_densitytensor + diff --git a/docs/densitytensor/densitytensor_to_measurement_probabilities.rst b/docs/densitytensor/densitytensor_to_measurement_probabilities.rst new file mode 100644 index 0000000..50beb3a --- /dev/null +++ b/docs/densitytensor/densitytensor_to_measurement_probabilities.rst @@ -0,0 +1,5 @@ +densitytensor_to_measurement_probabilities +============================================== + +.. autofunction:: qujax.densitytensor_to_measurement_probabilities + diff --git a/docs/densitytensor/get_densitytensor_to_expectation_func.rst b/docs/densitytensor/get_densitytensor_to_expectation_func.rst new file mode 100644 index 0000000..c4a2d09 --- /dev/null +++ b/docs/densitytensor/get_densitytensor_to_expectation_func.rst @@ -0,0 +1,5 @@ +get_densitytensor_to_expectation_func +======================================= + +.. autofunction:: qujax.get_densitytensor_to_expectation_func + diff --git a/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst b/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst new file mode 100644 index 0000000..c71cdbe --- /dev/null +++ b/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst @@ -0,0 +1,5 @@ +get_densitytensor_to_sampled_expectation_func +================================================ + +.. autofunction:: qujax.get_densitytensor_to_sampled_expectation_func + diff --git a/docs/densitytensor/get_params_to_densitytensor_func.rst b/docs/densitytensor/get_params_to_densitytensor_func.rst new file mode 100644 index 0000000..554b887 --- /dev/null +++ b/docs/densitytensor/get_params_to_densitytensor_func.rst @@ -0,0 +1,5 @@ +get_params_to_densitytensor_func +=================================== + +.. autofunction:: qujax.get_params_to_densitytensor_func + diff --git a/docs/densitytensor/kraus.rst b/docs/densitytensor/kraus.rst new file mode 100644 index 0000000..8f3959e --- /dev/null +++ b/docs/densitytensor/kraus.rst @@ -0,0 +1,5 @@ +kraus +======================= + +.. autofunction:: qujax.kraus + diff --git a/docs/densitytensor/partial_trace.rst b/docs/densitytensor/partial_trace.rst new file mode 100644 index 0000000..e432d02 --- /dev/null +++ b/docs/densitytensor/partial_trace.rst @@ -0,0 +1,5 @@ +partial_trace +=============================== + +.. autofunction:: qujax.partial_trace + diff --git a/docs/densitytensor/statetensor_to_densitytensor.rst b/docs/densitytensor/statetensor_to_densitytensor.rst new file mode 100644 index 0000000..3b66c4d --- /dev/null +++ b/docs/densitytensor/statetensor_to_densitytensor.rst @@ -0,0 +1,5 @@ +statetensor_to_densitytensor +=============================== + +.. autofunction:: qujax.statetensor_to_densitytensor + diff --git a/docs/index.rst b/docs/index.rst index c52f038..2a75cb1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Docs sample_bitstrings check_circuit print_circuit + densitytensor gates diff --git a/qujax/__init__.py b/qujax/__init__.py index 9af7a26..487f872 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -2,22 +2,39 @@ from qujax import gates -from qujax.circuit import UnionCallableOptionalArray -from qujax.circuit import apply_gate -from qujax.circuit import get_params_to_statetensor_func +from qujax.statetensor import apply_gate +from qujax.statetensor import get_params_to_statetensor_func -from qujax.observable import get_statetensor_to_expectation_func -from qujax.observable import get_statetensor_to_sampled_expectation_func -from qujax.observable import integers_to_bitstrings -from qujax.observable import bitstrings_to_integers -from qujax.observable import sample_integers -from qujax.observable import sample_bitstrings +from qujax.statetensor_observable import statetensor_to_single_expectation +from qujax.statetensor_observable import get_statetensor_to_expectation_func +from qujax.statetensor_observable import get_statetensor_to_sampled_expectation_func -from qujax.circuit_tools import check_unitary -from qujax.circuit_tools import check_circuit -from qujax.circuit_tools import print_circuit +from qujax.densitytensor import _kraus_single +from qujax.densitytensor import kraus +from qujax.densitytensor import get_params_to_densitytensor_func +from qujax.densitytensor import partial_trace + +from qujax.densitytensor_observable import densitytensor_to_single_expectation +from qujax.densitytensor_observable import get_densitytensor_to_expectation_func +from qujax.densitytensor_observable import get_densitytensor_to_sampled_expectation_func +from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities +from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor + +from qujax.utils import UnionCallableOptionalArray +from qujax.utils import check_unitary +from qujax.utils import check_hermitian +from qujax.utils import check_circuit +from qujax.utils import print_circuit +from qujax.utils import integers_to_bitstrings +from qujax.utils import bitstrings_to_integers +from qujax.utils import sample_integers +from qujax.utils import sample_bitstrings +from qujax.utils import statetensor_to_densitytensor del version -del circuit -del observable -del circuit_tools +del statetensor +del statetensor_observable +del densitytensor +del densitytensor_observable +del utils + diff --git a/qujax/densitytensor.py b/qujax/densitytensor.py new file mode 100644 index 0000000..394f991 --- /dev/null +++ b/qujax/densitytensor.py @@ -0,0 +1,200 @@ +from __future__ import annotations +from typing import Sequence, Union, Callable, Iterable, Tuple +from jax import numpy as jnp +from jax.lax import scan + +from qujax.statetensor import apply_gate, UnionCallableOptionalArray +from qujax.statetensor import _to_gate_func, _arrayify_inds, _gate_func_to_unitary +from qujax.utils import check_circuit, kraus_op_type + + +def _kraus_single(densitytensor: jnp.ndarray, + array: jnp.ndarray, + qubit_inds: Sequence[int]) -> jnp.ndarray: + r""" + Performs single Kraus operation + + .. math:: + + \rho_\text{out} = B \rho_\text{in} B^{\dagger} + + Args: + densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits + array: Array containing the Kraus operator (in tensor form). + qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. + + Returns: + Updated density matrix. + """ + n_qubits = densitytensor.ndim // 2 + densitytensor = apply_gate(densitytensor, array, qubit_inds) + densitytensor = apply_gate(densitytensor, array.conj(), [n_qubits + i for i in qubit_inds]) + return densitytensor + + +def kraus(densitytensor: jnp.ndarray, + arrays: Iterable[jnp.ndarray], + qubit_inds: Sequence[int]) -> jnp.ndarray: + r""" + Performs Kraus operation. + + .. math:: + \rho_\text{out} = \sum_i B_i \rho_\text{in} B_i^{\dagger} + + Args: + densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits + arrays: Sequence of arrays containing the Kraus operators (in tensor form). + qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. + + Returns: + Updated density matrix. + """ + arrays = jnp.array(arrays) + if arrays.ndim % 2 == 0: + arrays = arrays[jnp.newaxis] + # ensure first dimensions indexes different kraus operators + arrays = arrays.reshape((arrays.shape[0],) + (2,) * 2 * len(qubit_inds)) + + new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), + init=jnp.zeros_like(densitytensor) * 0.j, xs=arrays) + # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds).sum(0) + return new_densitytensor + + +def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, + param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) \ + -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], + Sequence[jnp.ndarray]]: + """ + Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to tensors + and that each element of param_inds_seq is a sequence of arrays that correspond to the parameter indices + of each Kraus operator. + + Args: + kraus_op: Either a normal gate_type or a sequence of gate_types representing Kraus operators. + param_inds: If kraus_op is a normal gate_type then a sequence of parameter indices, + if kraus_op is a sequence of Kraus operators then a sequence of sequences of parameter indices + + Returns: + Tuple containing sequence of functions mapping to Kraus operators + and sequence of arrays with parameter indices + + """ + if param_inds is None: + param_inds = [None for _ in kraus_op] + + if isinstance(kraus_op, (list, tuple)): + kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op] + else: + kraus_op_funcs = [_to_gate_func(kraus_op)] + param_inds = [param_inds] + return kraus_op_funcs, _arrayify_inds(param_inds) + + +def partial_trace(densitytensor: jnp.ndarray, + indices_to_trace: Sequence[int]) -> jnp.ndarray: + """ + Traces out (discards) specified qubits, resulting in a densitytensor + representing the mixed quantum state on the remaining qubits. + + Args: + densitytensor: Input densitytensor. + indices_to_trace: Indices of qubits to trace out/discard. + + Returns: + Resulting densitytensor on remaining qubits. + + """ + n_qubits = densitytensor.ndim // 2 + einsum_indices = list(range(densitytensor.ndim)) + for i in indices_to_trace: + einsum_indices[i + n_qubits] = einsum_indices[i] + densitytensor = jnp.einsum(densitytensor, einsum_indices) + return densitytensor + + +def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]], + n_qubits: int = None) -> UnionCallableOptionalArray: + """ + Creates a function that maps circuit parameters to a density tensor (a density matrix in tensor form). + densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits) + densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits) + + Args: + kraus_ops_seq: Sequence of gates. + Each element is either a string matching a unitary array or function in qujax.gates, + a custom unitary array or a custom function taking parameters and returning a unitary array. + Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, + the second gate is a two qubit gate acting on the zeroth and first qubit etc. + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. + n_qubits: Number of qubits, if fixed. + + Returns: + Function which maps parameters (and optional densitytensor_in) to a densitytensor. + If no parameters are found then the function only takes optional densitytensor_in. + + """ + + check_circuit(kraus_ops_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) + + if n_qubits is None: + n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 + + kraus_ops_seq_callable_and_param_inds = [_to_kraus_operator_seq_funcs(ko, param_inds) + for ko, param_inds in zip(kraus_ops_seq, param_inds_seq)] + kraus_ops_seq_callable = [ko_pi[0] for ko_pi in kraus_ops_seq_callable_and_param_inds] + param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds] + + def params_to_densitytensor_func(params: jnp.ndarray, + densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: + """ + Applies parameterised circuit (series of gates) to a densitytensor_in (default is |0>^N <0|^N). + + Args: + params: Parameters of the circuit. + densitytensor_in: Optional. Input densitytensor. + Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index). + + Returns: + Updated densitytensor. + + """ + if densitytensor_in is None: + densitytensor = jnp.zeros((2,) * 2 * n_qubits) + densitytensor = densitytensor.at[(0,) * 2 * n_qubits].set(1.) + else: + densitytensor = densitytensor_in + params = jnp.atleast_1d(params) + for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip(kraus_ops_seq_callable, qubit_inds_seq, + param_inds_array_seq): + kraus_operators = [_gate_func_to_unitary(gf, qubit_inds, pi, params) + for gf, pi in zip(gate_func_single_seq, param_inds_single_seq)] + densitytensor = kraus(densitytensor, kraus_operators, qubit_inds) + return densitytensor + + non_parameterised = all([all([pi.size == 0 for pi in pi_seq]) for pi_seq in param_inds_array_seq]) + if non_parameterised: + def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: + """ + Applies circuit (series of gates with no parameters) to a densitytensor_in (default is |0>^N <0|^N). + + Args: + densitytensor_in: Optional. Input densitytensor. + Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index). + + Returns: + Updated densitytensor. + + """ + return params_to_densitytensor_func(jnp.array([]), densitytensor_in) + + return no_params_to_densitytensor_func + + return params_to_densitytensor_func diff --git a/qujax/densitytensor_observable.py b/qujax/densitytensor_observable.py new file mode 100644 index 0000000..2de7079 --- /dev/null +++ b/qujax/densitytensor_observable.py @@ -0,0 +1,157 @@ +from __future__ import annotations +from typing import Sequence, Union, Callable +from jax import numpy as jnp, random +from jax.lax import fori_loop + +from qujax.densitytensor import _kraus_single, partial_trace +from qujax.statetensor_observable import _get_tensor_to_expectation_func +from qujax.utils import sample_integers, statetensor_to_densitytensor, bitstrings_to_integers + + +def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, + hermitian: jnp.ndarray, + qubit_inds: Sequence[int]) -> float: + """ + Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). + + Args: + densitytensor: Input densitytensor. + hermitian: Hermitian matrix representing observable + must be in tensor form with shape (2,2,...). + qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. + Must have 2 * len(qubit_inds) == hermitian.ndim + Returns: + Expected value (float). + """ + n_qubits = densitytensor.ndim // 2 + dt_indices = 2 * list(range(n_qubits)) + hermitian_indices = [i + densitytensor.ndim // 2 for i in range(hermitian.ndim)] + for n, q in enumerate(qubit_inds): + dt_indices[q] = hermitian_indices[n + len(qubit_inds)] + dt_indices[q + n_qubits] = hermitian_indices[n] + return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real + + +def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a densitytensor into an expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) + or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes densitytensor and returns expected value (float). + """ + + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, + densitytensor_to_single_expectation) + + +def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + """ + Converts strings (or arrays) representing Hermitian matrices, qubit indices and + coefficients into a function that converts a densitytensor into a sampled expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes densitytensor, random key and integer number of shots + and returns sampled expected value (float). + """ + densitytensor_to_expectation_func = get_densitytensor_to_expectation_func(hermitian_seq_seq, + qubits_seq_seq, + coefficients) + + def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, + random_key: random.PRNGKeyArray, + n_samps: int) -> float: + """ + Maps statetensor to sampled expected value. + + Args: + statetensor: Input statetensor. + random_key: JAX random key + n_samps: Number of samples contributing to sampled expectation. + + Returns: + Sampled expected value (float). + + """ + sampled_integers = sample_integers(random_key, statetensor, n_samps) + sampled_probs = fori_loop(0, n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size)) + + sampled_probs /= n_samps + sampled_dt = statetensor_to_densitytensor(jnp.sqrt(sampled_probs).reshape(statetensor.shape)) + return densitytensor_to_expectation_func(sampled_dt) + + return densitytensor_to_sampled_expectation_func + + +def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, + qubit_inds: Sequence[int]) -> jnp.ndarray: + """ + Extract array of measurement probabilities given a densitytensor and some qubit indices to measure + (in the computational basis). + I.e. the ith element of the array corresponds to the probability of observing the bitstring + represented by the integer i on the measured qubits. + + Args: + densitytensor: Input densitytensor. + qubit_inds: Sequence of qubit indices to measure. + + Returns: + Normalised array of measurement probabilities. + """ + n_qubits = densitytensor.ndim // 2 + n_qubits_measured = len(qubit_inds) + qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds] + return jnp.diag(partial_trace(densitytensor, qubit_inds_trace_out).reshape(2 * n_qubits_measured, + 2 * n_qubits_measured)).real + + +def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, + qubit_inds: Sequence[int], + measurement: Union[int, jnp.ndarray]) -> jnp.ndarray: + """ + Returns the post-measurement densitytensor assuming that qubit_inds are measured + (in the computational basis) and the given measurement (integer or bitstring) is observed. + + Args: + densitytensor: Input densitytensor. + qubit_inds: Sequence of qubit indices to measure. + measurement: Observed integer or bitstring. + + Returns: + Post-measurement densitytensor (same shape as input densitytensor). + """ + measurement = jnp.array(measurement) + measured_int = bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement + + n_qubits = densitytensor.ndim // 2 + n_qubits_measured = len(qubit_inds) + qubit_inds_projector = jnp.diag(jnp.zeros(2 ** n_qubits_measured).at[measured_int].set(1)) \ + .reshape((2,) * 2 * n_qubits_measured) + unnorm_densitytensor = _kraus_single(densitytensor, qubit_inds_projector, qubit_inds) + norm_const = jnp.trace(unnorm_densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)).real + return unnorm_densitytensor / norm_const diff --git a/qujax/observable.py b/qujax/observable.py deleted file mode 100644 index 376b0c1..0000000 --- a/qujax/observable.py +++ /dev/null @@ -1,223 +0,0 @@ -from __future__ import annotations -from typing import Sequence, Callable, Union, Optional - -from jax import numpy as jnp, random -from jax.lax import fori_loop - -from qujax import gates - - -def _statetensor_to_single_expectation_func(gate_tensor: jnp.ndarray, - qubit_inds: Sequence[int]) -> Callable[[jnp.ndarray], float]: - """ - Creates a function that maps statetensor to its expected value under the given gate unitary and qubit indices. - - Args: - gate_tensor: Gate unitary in tensor form. - qubit_inds: Sequence of integer qubit indices to apply gate to. - - Returns: - Function that takes statetensor and returns expected value (float). - """ - - def statetensor_to_single_expectation(statetensor: jnp.ndarray) -> float: - """ - Evaluates expected value of statetensor through gate. - - Args: - statetensor: Input statetensor. - - Returns: - Expected value (float). - """ - statetensor_new = jnp.tensordot(gate_tensor, statetensor, - axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) - statetensor_new = jnp.moveaxis(statetensor_new, list(range(len(qubit_inds))), qubit_inds) - axes = tuple(range(statetensor.ndim)) - return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real - - return statetensor_to_single_expectation - - -def get_statetensor_to_expectation_func(gate_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray], float]: - """ - Converts gate strings (or arrays), qubit indices and coefficients into a function that - converts statetensor into expected value. - - Args: - gate_seq_seq: Sequence of sequences of gates. - Each gate is either a tensor (jnp.ndarray) or a string corresponding to an array in qujax.gates. - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes statetensor and returns expected value (float). - """ - - def get_gate_tensor(gate_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: - """ - Convert sequence of gate strings into single gate unitary (in tensor form). - - Args: - gate_seq: Sequence of gate strings or arrays. - - Returns: - Single gate unitary in tensor form (array). - - """ - single_gate_arrs = [gates.__dict__[gate] if isinstance(gate, str) else gate for gate in gate_seq] - single_gate_arrs = [gate_arr.reshape((2,) * int(jnp.log2(gate_arr.size))) - for gate_arr in single_gate_arrs] - full_gate_mat = single_gate_arrs[0] - for single_gate_matrix in single_gate_arrs[1:]: - full_gate_mat = jnp.kron(full_gate_mat, single_gate_matrix) - full_gate_mat = full_gate_mat.reshape((2,) * int(jnp.log2(full_gate_mat.size))) - return full_gate_mat - - apply_gate_funcs = [_statetensor_to_single_expectation_func(get_gate_tensor(gns), qi) - for gns, qi in zip(gate_seq_seq, qubits_seq_seq)] - - def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: - """ - Maps statetensor to expected value. - - Args: - statetensor: Input statetensor. - - Returns: - Expected value (float). - - """ - out = 0 - for coeff, f in zip(coefficients, apply_gate_funcs): - out += coeff * f(statetensor) - return out - - return statetensor_to_expectation_func - - -def get_statetensor_to_sampled_expectation_func(gate_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: - """ - Converts gate strings (or arrays), qubit indices and coefficients into a function that - converts statetensor into a sampled expectation value. - - Args: - gate_seq_seq: Sequence of sequences of gates. - Each gate is either a tensor (jnp.ndarray) or a string corresponding to an array in qujax.gates. - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes statetensor, random key and integer number of shots - and returns sampled expected value (float). - """ - statetensor_to_expectation_func = get_statetensor_to_expectation_func(gate_seq_seq, qubits_seq_seq, coefficients) - - def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, - random_key: random.PRNGKeyArray, - n_samps: int) -> float: - """ - Maps statetensor to sampled expected value. - - Args: - statetensor: Input statetensor. - random_key: JAX random key - n_samps: Number of samples contributing to sampled expectation. - - Returns: - Sampled expected value (float). - - """ - sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop(0, n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size)) - - sampled_probs /= n_samps - sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape) - return statetensor_to_expectation_func(sampled_st) - - return statetensor_to_sampled_expectation_func - - -def integers_to_bitstrings(integers: Union[int, jnp.ndarray], - nbits: int = None) -> jnp.ndarray: - """ - Convert integer or array of integers into their binary expansion(s). - - Args: - integers: Integer or array of integers to be converted. - nbits: Length of output binary expansion. - Defaults to smallest possible. - - Returns: - Array of binary expansion(s). - """ - integers = jnp.atleast_1d(integers) - if nbits is None: - nbits = (jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5)).astype(int) - - return jnp.squeeze(((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int)) - - -def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: - """ - Convert binary expansion(s) into integers. - - Args: - bitstrings: Bitstring array or array of bitstring arrays. - - Returns: - Array of integers. - """ - bitstrings = jnp.atleast_2d(bitstrings) - convarr = 2 ** jnp.arange(bitstrings.shape[-1] - 1, -1, -1) - return jnp.squeeze(bitstrings.dot(convarr)).astype(int) - - -def sample_integers(random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1) -> jnp.ndarray: - """ - Generate random integer samples according to statetensor. - - Args: - random_key: JAX random key to seed samples. - statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). - n_samps: Number of samples to generate. Defaults to 1. - - Returns: - Array with sampled integers, shape=(n_samps,). - - """ - sv_probs = jnp.square(jnp.abs(statetensor.flatten())) - sampled_inds = random.choice(random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs) - return sampled_inds - - -def sample_bitstrings(random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1) -> jnp.ndarray: - """ - Generate random bitstring samples according to statetensor. - - Args: - random_key: JAX random key to seed samples. - statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). - n_samps: Number of samples to generate. Defaults to 1. - - Returns: - Array with sampled bitstrings, shape=(n_samps, statetensor.ndim). - - """ - return integers_to_bitstrings(sample_integers(random_key, statetensor, n_samps), statetensor.ndim) diff --git a/qujax/circuit.py b/qujax/statetensor.py similarity index 60% rename from qujax/circuit.py rename to qujax/statetensor.py index 259c6d5..f0bda56 100644 --- a/qujax/circuit.py +++ b/qujax/statetensor.py @@ -1,35 +1,22 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, Protocol +from typing import Sequence, Union, Callable from jax import numpy as jnp from qujax import gates -from qujax.circuit_tools import check_circuit - - -class CallableArrayAndOptionalArray(Protocol): - def __call__(self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: - ... - - -class CallableOptionalArray(Protocol): - def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: - ... - - -UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] +from qujax.utils import check_circuit, _arrayify_inds, UnionCallableOptionalArray, gate_type def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: """ Applies gate to statetensor and returns updated statetensor. - Gate is represented by a unitary matrix (i.e. not parameterised). + Gate is represented by a unitary matrix in tensor form. Args: statetensor: Input statetensor. gate_unitary: Unitary array representing gate must be in tensor form with shape (2,2,...). qubit_inds: Sequence of indices for gate to be applied to. - 2 * len(qubit_inds) is equal to the dimension of the gate unitary tensor. + Must have 2 * len(qubit_inds) = gate_unitary.ndim Returns: Updated statetensor. @@ -40,57 +27,61 @@ def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: return statetensor -def _to_gate_funcs(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]])\ - -> Sequence[Callable[[jnp.ndarray], jnp.ndarray]]: +def _to_gate_func(gate: gate_type) -> Callable[[jnp.ndarray], jnp.ndarray]: """ - Ensures all gate_seq elements are functions that map (possibly empty) parameters + Ensures a gate_seq element is a function that map (possibly empty) parameters to a unitary tensor. Args: - gate_seq: Sequence of gates. - Each element is either a string matching an array or function in qujax.gates, + gate: Either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. Returns: - Sequence of gate parameter to unitary functions - + Gate parameter to unitary functions """ + def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: return lambda: arr - gate_seq_callable = [] - for gate in gate_seq: - if isinstance(gate, str): - gate = gates.__dict__[gate] + if isinstance(gate, str): + gate = gates.__dict__[gate] - if callable(gate): - gate_func = gate - elif hasattr(gate, '__array__'): - gate_func = _array_to_callable(jnp.array(gate)) - else: - raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' - f'callable: {gate}') - gate_seq_callable.append(gate_func) + if callable(gate): + gate_func = gate + elif hasattr(gate, '__array__'): + gate_func = _array_to_callable(jnp.array(gate)) + else: + raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' + f'callable: {gate}') + return gate_func - return gate_seq_callable +def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray], + qubit_inds: Sequence[int], + param_inds: jnp.ndarray, + params: jnp.ndarray) -> jnp.ndarray: + """ + Extract gate unitary. -def _arrayify_inds(param_inds_seq: Sequence[Sequence[int]]) -> Sequence[jnp.ndarray]: - param_inds_seq = [jnp.array(p) for p in param_inds_seq] - param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] - return param_inds_seq + Args: + gate_func: Function that maps a (possibly empty) parameter array to a unitary tensor (array) + qubit_inds: Indices of qubits to apply gate to (only needed to ensure gate is in tensor form) + param_inds: Indices of full parameter to extract gate specific parameters + params: Full parameter vector + + Returns: + Array containing gate unitary in tensor form. + """ + gate_params = jnp.take(params, param_inds) + gate_unitary = gate_func(*gate_params) + gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + return gate_unitary -def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def get_params_to_statetensor_func(gate_seq: Sequence[gate_type], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[None, Sequence[int]]], n_qubits: int = None) -> UnionCallableOptionalArray: """ Creates a function that maps circuit parameters to a statetensor. @@ -120,8 +111,8 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - gate_seq_callable = _to_gate_funcs(gate_seq) - param_inds_seq = _arrayify_inds(param_inds_seq) + gate_seq_callable = [_to_gate_func(g) for g in gate_seq] + param_inds_array_seq = _arrayify_inds(param_inds_seq) def params_to_statetensor_func(params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: @@ -143,14 +134,13 @@ def params_to_statetensor_func(params: jnp.ndarray, else: statetensor = statetensor_in params = jnp.atleast_1d(params) - for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): - gate_params = jnp.take(params, param_inds) - gate_unitary = gate_func(*gate_params) - gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_array_seq): + gate_unitary = _gate_func_to_unitary(gate_func, qubit_inds, param_inds, params) statetensor = apply_gate(statetensor, gate_unitary, qubit_inds) return statetensor - if all([pi.size == 0 for pi in param_inds_seq]): + non_parameterised = all([pi.size == 0 for pi in param_inds_array_seq]) + if non_parameterised: def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.ndarray: """ Applies circuit (series of gates with no parameters) to a statetensor_in (default is |0>^N). diff --git a/qujax/statetensor_observable.py b/qujax/statetensor_observable.py new file mode 100644 index 0000000..ab1a08f --- /dev/null +++ b/qujax/statetensor_observable.py @@ -0,0 +1,171 @@ +from __future__ import annotations +from typing import Sequence, Callable, Union +from jax import numpy as jnp, random +from jax.lax import fori_loop + +from qujax.statetensor import apply_gate +from qujax.utils import check_hermitian, sample_integers, paulis + + +def statetensor_to_single_expectation(statetensor: jnp.ndarray, + hermitian: jnp.ndarray, + qubit_inds: Sequence[int]) -> float: + """ + Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). + + Args: + statetensor: Input statetensor. + hermitian: Hermitian array + must be in tensor form with shape (2,2,...). + qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. + Must have 2 * len(qubit_inds) == hermitian.ndim + + Returns: + Expected value (float). + """ + statetensor_new = apply_gate(statetensor, hermitian, qubit_inds) + axes = tuple(range(statetensor.ndim)) + return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real + + +def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: + """ + Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form + into single array (in tensor form). + + Args: + hermitian_seq: Sequence of Hermitian strings or arrays. + + Returns: + Hermitian matrix in tensor form (array). + """ + for h in hermitian_seq: + check_hermitian(h) + + single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] + single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] + + full_mat = single_arrs[0] + for single_matrix in single_arrs[1:]: + full_mat = jnp.kron(full_mat, single_matrix) + full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) + return full_mat + + +def _get_tensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], + contraction_function: Callable) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a tensor into an expected value. + The contraction function performs the tensor contraction according to the type of tensor provided + (i.e. whether it is a statetensor or a densitytensor). + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + contraction_function: Function that performs the tensor contraction. + + Returns: + Function that takes tensor and returns expected value (float). + """ + + hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] + + def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: + """ + Maps statetensor to expected value. + + Args: + statetensor: Input statetensor. + + Returns: + Expected value (float). + """ + out = 0 + for hermitian, qubit_inds, coeff in zip(hermitian_tensors, qubits_seq_seq, coefficients): + out += coeff * contraction_function(statetensor, hermitian, qubit_inds) + return out + + return statetensor_to_expectation_func + + +def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a statetensor into an expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) + or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes statetensor and returns expected value (float). + """ + + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, + statetensor_to_single_expectation) + + +def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + """ + Converts strings (or arrays) representing Hermitian matrices, qubit indices and + coefficients into a function that converts a statetensor into a sampled expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes statetensor, random key and integer number of shots + and returns sampled expected value (float). + """ + statetensor_to_expectation_func = get_statetensor_to_expectation_func(hermitian_seq_seq, + qubits_seq_seq, + coefficients) + + def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, + random_key: random.PRNGKeyArray, + n_samps: int) -> float: + """ + Maps statetensor to sampled expected value. + + Args: + statetensor: Input statetensor. + random_key: JAX random key + n_samps: Number of samples contributing to sampled expectation. + + Returns: + Sampled expected value (float). + """ + sampled_integers = sample_integers(random_key, statetensor, n_samps) + sampled_probs = fori_loop(0, n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size)) + + sampled_probs /= n_samps + sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape) + return statetensor_to_expectation_func(sampled_st) + + return statetensor_to_sampled_expectation_func diff --git a/qujax/circuit_tools.py b/qujax/utils.py similarity index 56% rename from qujax/circuit_tools.py rename to qujax/utils.py index ed1ff0a..983320d 100644 --- a/qujax/circuit_tools.py +++ b/qujax/utils.py @@ -1,17 +1,41 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, List, Tuple, Optional +from typing import Sequence, Union, Callable, List, Tuple, Optional, Protocol, Iterable import collections.abc from inspect import signature - -from jax import numpy as jnp +from jax import numpy as jnp, random from qujax import gates +paulis = {'X': gates.X, 'Y': gates.Y, 'Z': gates.Z} + + +class CallableArrayAndOptionalArray(Protocol): + def __call__(self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + ... + + +class CallableOptionalArray(Protocol): + def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + ... + + +UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] +gate_type = Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]] +kraus_op_type = Union[gate_type, Iterable[gate_type]] + + +def check_unitary(gate: gate_type): + """ + Checks whether a matrix or tensor is unitary. -def check_unitary(gate: Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]): + Args: + gate: array containing potentially unitary string, array + or function (which will be evaluated with all arguments set to 0.1). + + """ if isinstance(gate, str): if gate in gates.__dict__: gate = gates.__dict__[gate] @@ -35,13 +59,49 @@ def check_unitary(gate: Union[str, raise TypeError(f'Gate not unitary: {gate}') -def check_circuit(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def check_hermitian(hermitian: Union[str, jnp.ndarray]): + """ + Checks whether a matrix or tensor is Hermitian. + + Args: + hermitian: array containing potentially Hermitian matrix or tensor + + """ + if isinstance(hermitian, str): + if hermitian not in paulis: + raise TypeError(f'qujax only accepts {tuple(paulis.keys())} as Hermitian strings, received: {hermitian}') + else: + n_qubits = hermitian.ndim // 2 + hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) + if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): + raise TypeError(f'Array not Hermitian: {hermitian}') + + +def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequence[jnp.ndarray]: + """ + Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) + + Args: + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. + + Returns: + Sequence of arrays representing parameter indices. + """ + if param_inds_seq is None: + param_inds_seq = [None] + param_inds_seq = [jnp.array(p) for p in param_inds_seq] + param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] + return param_inds_seq + + +def check_circuit(gate_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], - n_qubits: int = None): + n_qubits: int = None, + check_unitaries: bool = True): """ Basic checks that circuit arguments conform. @@ -50,11 +110,13 @@ def check_circuit(gate_seq: Sequence[Union[str, Each element is either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. + Or alternatively a sequence of the above representing Kraus operators. qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. param_inds_seq: Sequence of parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, the second gate is not parameterised and the third gates used the fifth and second parameters. n_qubits: Number of qubits, if fixed. + check_unitaries: boolean on whether to check if each gate represents a unitary matrix """ if not isinstance(gate_seq, collections.abc.Sequence): @@ -65,7 +127,8 @@ def check_circuit(gate_seq: Sequence[Union[str, raise TypeError('qubit_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]') if (not isinstance(param_inds_seq, collections.abc.Sequence)) or \ - (any([not (isinstance(p, collections.abc.Sequence) or hasattr(p, '__array__')) for p in param_inds_seq])): + (any([not (isinstance(p, collections.abc.Sequence) or hasattr(p, '__array__') or p is None) + for p in param_inds_seq])): raise TypeError('param_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]') if len(gate_seq) != len(qubit_inds_seq) or len(param_inds_seq) != len(param_inds_seq): @@ -75,15 +138,13 @@ def check_circuit(gate_seq: Sequence[Union[str, if n_qubits is not None and n_qubits < max([max(qi) for qi in qubit_inds_seq]) + 1: raise TypeError('n_qubits must be larger than largest qubit index in qubit_inds_seq') - for g in gate_seq: - check_unitary(g) + if check_unitaries: + for g in gate_seq: + check_unitary(g) -def _get_gate_str(gate_obj: Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]], - param_inds: Sequence[int]) -> str: +def _get_gate_str(gate_obj: kraus_op_type, + param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) -> str: """ Maps single gate object to a four character string representation @@ -91,12 +152,18 @@ def _get_gate_str(gate_obj: Union[str, gate_obj: Either a string matching a function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape e.g. (2,2,2,...) ) or a function taking parameters (can be empty) and returning gate unitary in tensor form. - param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 666th parameter. + Or alternatively, a sequence of Krause operators represented by strings, arrays or functions. + param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 5th parameter. Returns: Four character string representation of the gate """ + if isinstance(gate_obj, (tuple, list)) or (hasattr(gate_obj, '__array__') and gate_obj.ndim % 2 == 1): + # Kraus operators + gate_obj = 'Kr' + param_inds = jnp.unique(jnp.concatenate(_arrayify_inds(param_inds), axis=0)) + if isinstance(gate_obj, str): gate_str = gate_obj elif hasattr(gate_obj, '__array__'): @@ -114,7 +181,10 @@ def _get_gate_str(gate_obj: Union[str, if hasattr(param_inds, 'tolist'): param_inds = param_inds.tolist() - if param_inds == [] or param_inds == [None]: + if isinstance(param_inds, tuple): + param_inds = list(param_inds) + + if param_inds == [] or param_inds == [None] or param_inds is None: if len(gate_str) > 7: gate_str = gate_str[:6] + '.' else: @@ -160,10 +230,7 @@ def extend_row(row: str, qubit_row: bool) -> str: return out_rows, [True] * len(rows) -def print_circuit(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def print_circuit(gate_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], n_qubits: Optional[int] = None, @@ -180,6 +247,7 @@ def print_circuit(gate_seq: Sequence[Union[str, Each element is either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. + Or alternatively a sequence of the above representing Kraus operators. qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. param_inds_seq: Sequence of parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, @@ -195,7 +263,7 @@ def print_circuit(gate_seq: Sequence[Union[str, String representation of circuit """ - check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) gate_ind_max = min(len(gate_seq) - 1, gate_ind_max) if gate_ind_max < gate_ind_min: @@ -247,3 +315,93 @@ def print_circuit(gate_seq: Sequence[Union[str, print(p) return rows + + +def integers_to_bitstrings(integers: Union[int, jnp.ndarray], + nbits: int = None) -> jnp.ndarray: + """ + Convert integer or array of integers into their binary expansion(s). + + Args: + integers: Integer or array of integers to be converted. + nbits: Length of output binary expansion. + Defaults to smallest possible. + + Returns: + Array of binary expansion(s). + """ + integers = jnp.atleast_1d(integers) + if nbits is None: + nbits = (jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5)).astype(int) + + return jnp.squeeze(((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int)) + + +def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: + """ + Convert binary expansion(s) into integers. + + Args: + bitstrings: Bitstring array or array of bitstring arrays. + + Returns: + Array of integers. + """ + bitstrings = jnp.atleast_2d(bitstrings) + convarr = 2 ** jnp.arange(bitstrings.shape[-1] - 1, -1, -1) + return jnp.squeeze(bitstrings.dot(convarr)).astype(int) + + +def sample_integers(random_key: random.PRNGKeyArray, + statetensor: jnp.ndarray, + n_samps: Optional[int] = 1) -> jnp.ndarray: + """ + Generate random integer samples according to statetensor. + + Args: + random_key: JAX random key to seed samples. + statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). + n_samps: Number of samples to generate. Defaults to 1. + + Returns: + Array with sampled integers, shape=(n_samps,). + + """ + sv_probs = jnp.square(jnp.abs(statetensor.flatten())) + sampled_inds = random.choice(random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs) + return sampled_inds + + +def sample_bitstrings(random_key: random.PRNGKeyArray, + statetensor: jnp.ndarray, + n_samps: Optional[int] = 1) -> jnp.ndarray: + """ + Generate random bitstring samples according to statetensor. + + Args: + random_key: JAX random key to seed samples. + statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). + n_samps: Number of samples to generate. Defaults to 1. + + Returns: + Array with sampled bitstrings, shape=(n_samps, statetensor.ndim). + + """ + return integers_to_bitstrings(sample_integers(random_key, statetensor, n_samps), statetensor.ndim) + + +def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: + """ + Computes a densitytensor representation of a pure quantum state + from its statetensor representaton + + Args: + statetensor: Input statetensor. + + Returns: + A densitytensor representing the quantum state. + """ + n_qubits = statetensor.ndim + st = statetensor + dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) + return dt diff --git a/qujax/version.py b/qujax/version.py index cd9b137..0404d81 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = '0.2.9' +__version__ = '0.3.0' diff --git a/tests/test_circuits.py b/tests/test_circuits.py index dd3f3c6..7dd4219 100644 --- a/tests/test_circuits.py +++ b/tests/test_circuits.py @@ -37,7 +37,7 @@ def test_H_redundant_qubits(): def test_CX_Rz_CY(): gates = ['H', 'H', 'H', 'CX', 'Rz', 'CY'] qubits = [[0], [1], [2], [0, 1], [1], [1, 2]] - param_inds = [[], [], [], [], [0], []] + param_inds = [[], [], [], None, [0], []] param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) st = param_to_st(jnp.array(0.1)) diff --git a/tests/test_densitytensor.py b/tests/test_densitytensor.py new file mode 100644 index 0000000..8a79345 --- /dev/null +++ b/tests/test_densitytensor.py @@ -0,0 +1,251 @@ +from itertools import combinations +from jax import numpy as jnp, jit + +import qujax + + +def test_kraus_single(): + n_qubits = 3 + dim = 2 ** n_qubits + density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) + density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) + kraus_operator = qujax.gates.Rx(0.2) + + qubit_inds = (1,) + + unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) + unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) + check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T + + # qujax._kraus_single + qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + # qujax.kraus (but for a single array) + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + +def test_kraus_single_2qubit(): + n_qubits = 4 + dim = 2 ** n_qubits + density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) + density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) + kraus_operator_tensor = qujax.gates.ZZPhase(0.1) + kraus_operator = qujax.gates.ZZPhase(0.1).reshape(4, 4) + + qubit_inds = (1, 2) + + unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) + unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) + check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T + + # qujax._kraus_single + qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(density_tensor, + kraus_operator_tensor, + qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + # qujax.kraus (but for a single array) + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds) # check reshape kraus_operator correctly + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + +def test_kraus_multiple(): + n_qubits = 3 + dim = 2 ** n_qubits + density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) + density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) + + kraus_operators = [0.25 * qujax.gates.H, 0.25 * qujax.gates.Rx(0.3), 0.5 * qujax.gates.Ry(0.1)] + + qubit_inds = (1,) + + unitary_matrices = [jnp.kron(jnp.eye(2 * qubit_inds[0]), ko) for ko in kraus_operators] + unitary_matrices = [jnp.kron(um, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) for um in unitary_matrices] + + check_kraus_dm = jnp.zeros_like(density_matrix) + for um in unitary_matrices: + check_kraus_dm += um @ density_matrix @ um.conj().T + + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + +def test_params_to_densitytensor_func(): + n_qubits = 2 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_st = qujax.get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(n_qubits) / 10. + + st = params_to_st(params) + dt_test = qujax.statetensor_to_densitytensor(st) + + dt = params_to_dt(params) + + assert jnp.allclose(dt, dt_test) + + jit_dt = jit(params_to_dt)(params) + assert jnp.allclose(jit_dt, dt_test) + + +def test_params_to_densitytensor_func_with_bit_flip(): + n_qubits = 2 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_pre_bf_st = qujax.get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]] + kraus_qubit_inds = [(0,)] + kraus_param_inds = [None] + + gate_seq += kraus_ops + qubit_inds_seq += kraus_qubit_inds + param_inds_seq += kraus_param_inds + + _ = qujax.print_circuit(gate_seq, qubit_inds_seq, param_inds_seq) + + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(n_qubits) / 10. + + pre_bf_st = params_to_pre_bf_st(params) + pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) + dt_test = qujax.kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) + + dt = params_to_dt(params) + + assert jnp.allclose(dt, dt_test) + + jit_dt = jit(params_to_dt)(params) + assert jnp.allclose(jit_dt, dt_test) + + +def test_partial_trace_1(): + state1 = 1 / jnp.sqrt(2) * jnp.array([1., 1.]) + state2 = jnp.kron(state1, state1) + state3 = jnp.kron(state1, state2) + + dt1 = jnp.outer(state1, state1.conj()).reshape((2,) * 2) + dt2 = jnp.outer(state2, state2.conj()).reshape((2,) * 4) + dt3 = jnp.outer(state3, state3.conj()).reshape((2,) * 6) + + for i in range(3): + assert jnp.allclose(qujax.partial_trace(dt3, [i]), dt2) + + for i in combinations(range(3), 2): + assert jnp.allclose(qujax.partial_trace(dt3, i), dt1) + + +def test_partial_trace_2(): + n_qubits = 3 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(1, n_qubits + 1) / 10. + + dt = params_to_dt(params) + dt_discard_test = jnp.trace(dt, axis1=0, axis2=n_qubits) + dt_discard = qujax.partial_trace(dt, [0]) + + assert jnp.allclose(dt_discard, dt_discard_test) + + +def test_measure(): + n_qubits = 3 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(1, n_qubits + 1) / 10. + + dt = params_to_dt(params) + + qubit_inds = [0] + + all_probs = jnp.diag(dt.reshape(2 ** n_qubits, 2 ** n_qubits)).real + all_probs_marginalise \ + = all_probs.reshape((2,) * n_qubits).sum(axis=[i for i in range(n_qubits) if i not in qubit_inds]) + + probs = qujax.densitytensor_to_measurement_probabilities(dt, qubit_inds) + + assert jnp.isclose(probs.sum(), 1.) + assert jnp.isclose(all_probs.sum(), 1.) + assert jnp.allclose(probs, all_probs_marginalise) + + dm = dt.reshape(2 ** n_qubits, 2 ** n_qubits) + projector = jnp.array([[1, 0], [0, 0]]) + for _ in range(n_qubits - 1): + projector = jnp.kron(projector, jnp.eye(2)) + measured_dm = projector @ dm @ projector.T.conj() + measured_dm /= jnp.trace(projector.T.conj() @ projector @ dm) + measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits) + + measured_dt = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) + measured_dt_bits = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, (0,)*n_qubits) + assert jnp.allclose(measured_dt_true, measured_dt) + assert jnp.allclose(measured_dt_true, measured_dt_bits) diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 8b4e707..0f168a5 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -1,8 +1,34 @@ -from jax import numpy as jnp, jit, grad, random +from jax import numpy as jnp, jit, grad, random, config import qujax +def test_pauli_hermitian(): + for p_str in ('X', 'Y', 'Z'): + qujax.check_hermitian(p_str) + qujax.check_hermitian(qujax.gates.__dict__[p_str]) + + +def test_single_expectation(): + Z = qujax.gates.Z + + st1 = jnp.zeros((2, 2, 2)) + st2 = jnp.zeros((2, 2, 2)) + st1 = st1.at[(0, 0, 0)].set(1.) + st2 = st2.at[(1, 0, 0)].set(1.) + dt1 = qujax.statetensor_to_densitytensor(st1) + dt2 = qujax.statetensor_to_densitytensor(st2) + ZZ = jnp.kron(Z, Z).reshape(2, 2, 2, 2) + + est1 = qujax.statetensor_to_single_expectation(dt1, ZZ, [0, 1]) + est2 = qujax.statetensor_to_single_expectation(dt2, ZZ, [0, 1]) + edt1 = qujax.densitytensor_to_single_expectation(dt1, ZZ, [0, 1]) + edt2 = qujax.densitytensor_to_single_expectation(dt2, ZZ, [0, 1]) + + assert est1.item() == edt1.item() == 1 + assert est2.item() == edt2.item() == -1 + + def test_bitstring_expectation(): n_qubits = 4 @@ -31,58 +57,104 @@ def st_to_expectation(statetensor): param_to_expectation = lambda p: st_to_expectation(param_to_st(p)) + def brute_force_param_to_exp(p): + sv = param_to_st(p).flatten() + return jnp.dot(sv, jnp.diag(costs) @ sv.conj()).real + + true_expectation = brute_force_param_to_exp(params) + expectation = param_to_expectation(params) expectation_jit = jit(param_to_expectation)(params) assert expectation.shape == () - assert expectation.dtype == 'float32' - assert jnp.abs(-0.97042876 - expectation) < 1e-5 - assert jnp.abs(-0.97042876 - expectation_jit) < 1e-5 + assert expectation.dtype.name[:5] == 'float' + assert jnp.isclose(true_expectation, expectation) + assert jnp.isclose(true_expectation, expectation_jit) + true_expectation_grad = grad(brute_force_param_to_exp)(params) expectation_grad = grad(param_to_expectation)(params) expectation_grad_jit = jit(grad(param_to_expectation))(params) - true_expectation_grad = jnp.array([5.1673526e-01, 1.2618620e+00, 5.1392573e-01, - 1.5056899e+00, 4.3226164e-02, 3.4227133e-02, - 8.1762001e-02, 7.7345759e-01, 5.1567715e-01, - -3.1131029e-01, -1.7132770e-01, -6.6244489e-01, - 9.3626760e-08, -4.6813380e-08, -2.3406690e-08, - -9.3626760e-08]) - assert expectation_grad.shape == (n_params,) - assert expectation_grad.dtype == 'float32' - assert jnp.all(jnp.abs(expectation_grad - true_expectation_grad) < 1e-5) - assert jnp.all(jnp.abs(expectation_grad_jit - true_expectation_grad) < 1e-5) + assert expectation_grad.dtype.name[:5] == 'float' + assert jnp.allclose(true_expectation_grad, expectation_grad, atol=1e-5) + assert jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5) -def test_ZZ_X(): - n_qubits = 5 +def test_ZZ_Y(): + config.update("jax_enable_x64", True) # Run this test with 64 bit precision - gate_str_seq_seq = [['Z', 'Z']] * (n_qubits - 1) + [['X']] * n_qubits - coefs = random.normal(random.PRNGKey(0), shape=(len(gate_str_seq_seq),)) + n_qubits = 4 + + hermitian_str_seq_seq = [['Z', 'Z']] * (n_qubits - 1) + [['Y']] * n_qubits + coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [[i] for i in range(n_qubits)] - st_to_exp = qujax.get_statetensor_to_expectation_func(gate_str_seq_seq, + st_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, + qubit_inds_seq, + coefs) + dt_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, qubit_inds_seq, coefs) state = random.uniform(random.PRNGKey(0), shape=(2 ** n_qubits,)) * 2 state /= jnp.linalg.norm(state) st_in = state.reshape((2,) * n_qubits) - - jax_exp = st_to_exp(st_in) - jax_exp_jit = jit(st_to_exp)(st_in) - - assert jnp.abs(-0.23738188 - jax_exp) < 1e-5 - assert jnp.abs(-0.23738188 - jax_exp_jit) < 1e-5 - - st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(gate_str_seq_seq, + dt_in = qujax.statetensor_to_densitytensor(st_in) + + def big_hermitian_matrix(hermitian_str_seq, qubit_inds): + qubit_arrs = [getattr(qujax.gates, s) for s in hermitian_str_seq] + hermitian_arrs = [] + j = 0 + for i in range(n_qubits): + if i in qubit_inds: + hermitian_arrs.append(qubit_arrs[j]) + j += 1 + else: + hermitian_arrs.append(jnp.eye(2)) + + big_h = hermitian_arrs[0] + for k in range(1, n_qubits): + big_h = jnp.kron(big_h, hermitian_arrs[k]) + return big_h + + sum_big_hs = jnp.zeros((2 ** n_qubits, 2 ** n_qubits), dtype='complex') + for i in range(len(hermitian_str_seq_seq)): + sum_big_hs += coefs[i] * big_hermitian_matrix(hermitian_str_seq_seq[i], qubit_inds_seq[i]) + + assert jnp.allclose(sum_big_hs, sum_big_hs.conj().T) + + sv = st_in.flatten() + true_exp = jnp.dot(sv, sum_big_hs @ sv.conj()).real + + qujax_exp = st_to_exp(st_in) + qujax_dt_exp = dt_to_exp(dt_in) + qujax_exp_jit = jit(st_to_exp)(st_in) + qujax_dt_exp_jit = jit(dt_to_exp)(dt_in) + + assert jnp.array(qujax_exp).shape == () + assert jnp.array(qujax_exp).dtype.name[:5] == 'float' + assert jnp.isclose(true_exp, qujax_exp) + assert jnp.isclose(true_exp, qujax_dt_exp) + assert jnp.isclose(true_exp, qujax_exp_jit) + assert jnp.isclose(true_exp, qujax_dt_exp_jit) + + st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, + qubit_inds_seq, + coefs) + dt_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, qubit_inds_seq, coefs) - jax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 10000) - jax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 10000) - assert jnp.abs(-0.23738188 - jax_samp_exp) < 1e-2 - assert jnp.abs(-0.23738188 - jax_samp_exp_jit) < 1e-2 + qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) + qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) + qujax_samp_exp_dt = dt_to_samp_exp(st_in, random.PRNGKey(1), 1000000) + qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) + assert jnp.array(qujax_samp_exp).shape == () + assert jnp.array(qujax_samp_exp).dtype.name[:5] == 'float' + assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) + assert jnp.isclose(true_exp, qujax_samp_exp_jit, rtol=1e-2) + assert jnp.isclose(true_exp, qujax_samp_exp_dt, rtol=1e-2) + assert jnp.isclose(true_exp, qujax_samp_exp_dt_jit, rtol=1e-2) def test_sampling(): diff --git a/tests/test_gates.py b/tests/test_gates.py index 9f41f59..3d29f0b 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -1,5 +1,4 @@ -from qujax import gates -from qujax.circuit_tools import check_unitary +from qujax import gates, check_unitary def test_gates():