Skip to content

Commit

Permalink
Merge pull request #52 from CQCL/develop
Browse files Browse the repository at this point in the history
Add densitytensor support
  • Loading branch information
SamDuffield authored Nov 7, 2022
2 parents dc18c4b + 7c7d305 commit f47cf3d
Show file tree
Hide file tree
Showing 24 changed files with 1,212 additions and 360 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# qujax

Represent a (parameterised) quantum circuit as a pure [JAX](https://github.com/google/jax) function that
takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes of the quantum state and can then be used
downstream for exact expectations, gradients or sampling.
takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes
of the quantum state and can then be used downstream for exact expectations, gradients or sampling.

A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support for GPUs/TPUs.
qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix and allows for mixed states and generic Kraus operators.

A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support
for GPUs/TPUs.

Some useful links:
- [Documentation](https://cqcl.github.io/qujax/api/)
Expand Down
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon']
extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon', 'sphinx.ext.mathjax']

templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
Expand All @@ -38,3 +38,5 @@
'Callable[[Optional[ndarray]], ndarray]]'

}

latex_engine = 'pdflatex'
14 changes: 14 additions & 0 deletions docs/densitytensor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
densitytensor
=======================

.. toctree::

densitytensor/kraus
densitytensor/get_params_to_densitytensor_func
densitytensor/partial_trace
densitytensor/get_densitytensor_to_expectation_func
densitytensor/get_densitytensor_to_sampled_expectation_func
densitytensor/densitytensor_to_measurement_probabilities
densitytensor/densitytensor_to_measured_densitytensor
densitytensor/statetensor_to_densitytensor

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
densitytensor_to_measured_densitytensor
==============================================

.. autofunction:: qujax.densitytensor_to_measured_densitytensor

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
densitytensor_to_measurement_probabilities
==============================================

.. autofunction:: qujax.densitytensor_to_measurement_probabilities

5 changes: 5 additions & 0 deletions docs/densitytensor/get_densitytensor_to_expectation_func.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
get_densitytensor_to_expectation_func
=======================================

.. autofunction:: qujax.get_densitytensor_to_expectation_func

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
get_densitytensor_to_sampled_expectation_func
================================================

.. autofunction:: qujax.get_densitytensor_to_sampled_expectation_func

5 changes: 5 additions & 0 deletions docs/densitytensor/get_params_to_densitytensor_func.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
get_params_to_densitytensor_func
===================================

.. autofunction:: qujax.get_params_to_densitytensor_func

5 changes: 5 additions & 0 deletions docs/densitytensor/kraus.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
kraus
=======================

.. autofunction:: qujax.kraus

5 changes: 5 additions & 0 deletions docs/densitytensor/partial_trace.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
partial_trace
===============================

.. autofunction:: qujax.partial_trace

5 changes: 5 additions & 0 deletions docs/densitytensor/statetensor_to_densitytensor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
statetensor_to_densitytensor
===============================

.. autofunction:: qujax.statetensor_to_densitytensor

1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Docs
sample_bitstrings
check_circuit
print_circuit
densitytensor
gates <https://github.com/CQCL/qujax/blob/main/qujax/gates.py>


Expand Down
47 changes: 32 additions & 15 deletions qujax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,39 @@

from qujax import gates

from qujax.circuit import UnionCallableOptionalArray
from qujax.circuit import apply_gate
from qujax.circuit import get_params_to_statetensor_func
from qujax.statetensor import apply_gate
from qujax.statetensor import get_params_to_statetensor_func

from qujax.observable import get_statetensor_to_expectation_func
from qujax.observable import get_statetensor_to_sampled_expectation_func
from qujax.observable import integers_to_bitstrings
from qujax.observable import bitstrings_to_integers
from qujax.observable import sample_integers
from qujax.observable import sample_bitstrings
from qujax.statetensor_observable import statetensor_to_single_expectation
from qujax.statetensor_observable import get_statetensor_to_expectation_func
from qujax.statetensor_observable import get_statetensor_to_sampled_expectation_func

from qujax.circuit_tools import check_unitary
from qujax.circuit_tools import check_circuit
from qujax.circuit_tools import print_circuit
from qujax.densitytensor import _kraus_single
from qujax.densitytensor import kraus
from qujax.densitytensor import get_params_to_densitytensor_func
from qujax.densitytensor import partial_trace

from qujax.densitytensor_observable import densitytensor_to_single_expectation
from qujax.densitytensor_observable import get_densitytensor_to_expectation_func
from qujax.densitytensor_observable import get_densitytensor_to_sampled_expectation_func
from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities
from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor

from qujax.utils import UnionCallableOptionalArray
from qujax.utils import check_unitary
from qujax.utils import check_hermitian
from qujax.utils import check_circuit
from qujax.utils import print_circuit
from qujax.utils import integers_to_bitstrings
from qujax.utils import bitstrings_to_integers
from qujax.utils import sample_integers
from qujax.utils import sample_bitstrings
from qujax.utils import statetensor_to_densitytensor

del version
del circuit
del observable
del circuit_tools
del statetensor
del statetensor_observable
del densitytensor
del densitytensor_observable
del utils

200 changes: 200 additions & 0 deletions qujax/densitytensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from __future__ import annotations
from typing import Sequence, Union, Callable, Iterable, Tuple
from jax import numpy as jnp
from jax.lax import scan

from qujax.statetensor import apply_gate, UnionCallableOptionalArray
from qujax.statetensor import _to_gate_func, _arrayify_inds, _gate_func_to_unitary
from qujax.utils import check_circuit, kraus_op_type


def _kraus_single(densitytensor: jnp.ndarray,
array: jnp.ndarray,
qubit_inds: Sequence[int]) -> jnp.ndarray:
r"""
Performs single Kraus operation
.. math::
\rho_\text{out} = B \rho_\text{in} B^{\dagger}
Args:
densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits
array: Array containing the Kraus operator (in tensor form).
qubit_inds: Sequence of qubit indices on which to apply the Kraus operation.
Returns:
Updated density matrix.
"""
n_qubits = densitytensor.ndim // 2
densitytensor = apply_gate(densitytensor, array, qubit_inds)
densitytensor = apply_gate(densitytensor, array.conj(), [n_qubits + i for i in qubit_inds])
return densitytensor


def kraus(densitytensor: jnp.ndarray,
arrays: Iterable[jnp.ndarray],
qubit_inds: Sequence[int]) -> jnp.ndarray:
r"""
Performs Kraus operation.
.. math::
\rho_\text{out} = \sum_i B_i \rho_\text{in} B_i^{\dagger}
Args:
densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits
arrays: Sequence of arrays containing the Kraus operators (in tensor form).
qubit_inds: Sequence of qubit indices on which to apply the Kraus operation.
Returns:
Updated density matrix.
"""
arrays = jnp.array(arrays)
if arrays.ndim % 2 == 0:
arrays = arrays[jnp.newaxis]
# ensure first dimensions indexes different kraus operators
arrays = arrays.reshape((arrays.shape[0],) + (2,) * 2 * len(qubit_inds))

new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None),
init=jnp.zeros_like(densitytensor) * 0.j, xs=arrays)
# i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds).sum(0)
return new_densitytensor


def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type,
param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) \
-> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]],
Sequence[jnp.ndarray]]:
"""
Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to tensors
and that each element of param_inds_seq is a sequence of arrays that correspond to the parameter indices
of each Kraus operator.
Args:
kraus_op: Either a normal gate_type or a sequence of gate_types representing Kraus operators.
param_inds: If kraus_op is a normal gate_type then a sequence of parameter indices,
if kraus_op is a sequence of Kraus operators then a sequence of sequences of parameter indices
Returns:
Tuple containing sequence of functions mapping to Kraus operators
and sequence of arrays with parameter indices
"""
if param_inds is None:
param_inds = [None for _ in kraus_op]

if isinstance(kraus_op, (list, tuple)):
kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op]
else:
kraus_op_funcs = [_to_gate_func(kraus_op)]
param_inds = [param_inds]
return kraus_op_funcs, _arrayify_inds(param_inds)


def partial_trace(densitytensor: jnp.ndarray,
indices_to_trace: Sequence[int]) -> jnp.ndarray:
"""
Traces out (discards) specified qubits, resulting in a densitytensor
representing the mixed quantum state on the remaining qubits.
Args:
densitytensor: Input densitytensor.
indices_to_trace: Indices of qubits to trace out/discard.
Returns:
Resulting densitytensor on remaining qubits.
"""
n_qubits = densitytensor.ndim // 2
einsum_indices = list(range(densitytensor.ndim))
for i in indices_to_trace:
einsum_indices[i + n_qubits] = einsum_indices[i]
densitytensor = jnp.einsum(densitytensor, einsum_indices)
return densitytensor


def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type],
qubit_inds_seq: Sequence[Sequence[int]],
param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]],
n_qubits: int = None) -> UnionCallableOptionalArray:
"""
Creates a function that maps circuit parameters to a density tensor (a density matrix in tensor form).
densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits)
densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)
Args:
kraus_ops_seq: Sequence of gates.
Each element is either a string matching a unitary array or function in qujax.gates,
a custom unitary array or a custom function taking parameters and returning a unitary array.
Unitary arrays will be reshaped into tensor form (2, 2,...)
qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on.
i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit,
the second gate is a two qubit gate acting on the zeroth and first qubit etc.
param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
(the float at position zero in the parameter vector/array), the second gate is not parameterised
and the third gates used the parameters at position five and two.
n_qubits: Number of qubits, if fixed.
Returns:
Function which maps parameters (and optional densitytensor_in) to a densitytensor.
If no parameters are found then the function only takes optional densitytensor_in.
"""

check_circuit(kraus_ops_seq, qubit_inds_seq, param_inds_seq, n_qubits, False)

if n_qubits is None:
n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1

kraus_ops_seq_callable_and_param_inds = [_to_kraus_operator_seq_funcs(ko, param_inds)
for ko, param_inds in zip(kraus_ops_seq, param_inds_seq)]
kraus_ops_seq_callable = [ko_pi[0] for ko_pi in kraus_ops_seq_callable_and_param_inds]
param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds]

def params_to_densitytensor_func(params: jnp.ndarray,
densitytensor_in: jnp.ndarray = None) -> jnp.ndarray:
"""
Applies parameterised circuit (series of gates) to a densitytensor_in (default is |0>^N <0|^N).
Args:
params: Parameters of the circuit.
densitytensor_in: Optional. Input densitytensor.
Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index).
Returns:
Updated densitytensor.
"""
if densitytensor_in is None:
densitytensor = jnp.zeros((2,) * 2 * n_qubits)
densitytensor = densitytensor.at[(0,) * 2 * n_qubits].set(1.)
else:
densitytensor = densitytensor_in
params = jnp.atleast_1d(params)
for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip(kraus_ops_seq_callable, qubit_inds_seq,
param_inds_array_seq):
kraus_operators = [_gate_func_to_unitary(gf, qubit_inds, pi, params)
for gf, pi in zip(gate_func_single_seq, param_inds_single_seq)]
densitytensor = kraus(densitytensor, kraus_operators, qubit_inds)
return densitytensor

non_parameterised = all([all([pi.size == 0 for pi in pi_seq]) for pi_seq in param_inds_array_seq])
if non_parameterised:
def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp.ndarray:
"""
Applies circuit (series of gates with no parameters) to a densitytensor_in (default is |0>^N <0|^N).
Args:
densitytensor_in: Optional. Input densitytensor.
Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index).
Returns:
Updated densitytensor.
"""
return params_to_densitytensor_func(jnp.array([]), densitytensor_in)

return no_params_to_densitytensor_func

return params_to_densitytensor_func
Loading

0 comments on commit f47cf3d

Please sign in to comment.