Skip to content

Commit

Permalink
Update shadows.py and test_shadows.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PeilinZHENG committed Aug 2, 2023
1 parent 1b8715e commit f2185bd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 37 deletions.
60 changes: 33 additions & 27 deletions tensorcircuit/shadows.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
Classical Shadows functions
"""
from typing import Any, Union, Optional, Sequence
from typing import Any, Union, Optional, Sequence, Tuple, List
from string import ascii_letters as ABC
import numpy as np
from numpy import ndarray

from .cons import backend, dtypestr, rdtypestr
from .circuit import Circuit
Expand All @@ -12,8 +13,8 @@


def shadow_bound(
observables: Union[np.ndarray, Sequence[int]], epsilon: float, delta: float = 0.01
):
observables: Union[Tensor, Sequence[int]], epsilon: float, delta: float = 0.01
) -> Tuple[int, int]:
r"""Calculate the shadow bound of the Pauli observables, please refer to the Theorem S1 and Lemma S3 in Huang, H.-Y., R. Kueng, and J. Preskill, 2020, Nat. Phys. 16, 1050.
:param observables: shape = (nq,) or (M, nq), where nq is the number of qubits, M is the number of observables
Expand All @@ -28,12 +29,12 @@ def shadow_bound(
:return k: Number of equal parts to split the shadow snapshot states to compute the median of means. k=1 (default) corresponds to simply taking the mean over all shadow snapshot states.
:rtype: int
"""
observables = np.sign(np.asarray(observables))
if len(observables.shape) == 1:
observables = observables[None, :]
M = observables.shape[0]
count = np.sign(np.asarray(observables))
if len(count.shape) == 1:
count = count[None, :]
M = count.shape[0]
k = np.ceil(2 * np.log(2 * M / delta))
max_length = np.max(np.sum(observables, axis=1))
max_length = np.max(np.sum(count, axis=1))
N = np.ceil((34 / epsilon**2) * 3**max_length)
return int(N * k), int(k)

Expand All @@ -42,8 +43,9 @@ def shadow_snapshots(
psi: Tensor,
pauli_strings: Tensor,
status: Optional[Tensor] = None,
sub: Optional[Sequence[int]] = None,
measurement_only: bool = False,
):
) -> Tensor:
r"""To generate the shadow snapshots from given pauli string observables on $|\psi\rangle$
:param psi: shape = (2 ** nq, 2 ** nq), where nq is the number of qubits
Expand All @@ -52,6 +54,8 @@ def shadow_snapshots(
:type: Tensor
:param status: shape = None or (ns, repeat), where repeat is the times to measure on one pauli string
:type: Optional[Tensor]
:param sub: qubit indices of subsystem
:type: Optional[Sequence[int]]
:param measurement_only: return snapshots (True) or snapshot states (false), default=False
:type: bool
Expand Down Expand Up @@ -82,10 +86,10 @@ def shadow_snapshots(
dtype=rdtypestr,
) # (3, 3)

def proj_measure(pauli_string, st):
def proj_measure(pauli_string: Tensor, st: Tensor) -> Tensor:
c_ = Circuit(nq, inputs=psi)
for i in range(nq):
c_.R(
c_.r(
i,
theta=backend.gather1d(
backend.gather1d(angles, backend.gather1d(pauli_string, i)), 0
Expand All @@ -102,14 +106,14 @@ def proj_measure(pauli_string, st):
vpm = backend.vmap(proj_measure, vectorized_argnums=(0, 1))
snapshots = vpm(pauli_strings, status) # (ns, repeat, nq)
if measurement_only:
return snapshots
return snapshots if sub is None else slice_sub(snapshots, sub)
else:
return local_snapshot_states(snapshots, pauli_strings + 1)
return local_snapshot_states(snapshots, pauli_strings + 1, sub)


def local_snapshot_states(
snapshots: Tensor, pauli_strings: Tensor, sub: Optional[Sequence[int]] = None
):
) -> Tensor:
r"""To generate the local snapshots states from snapshots and pauli strings
:param snapshots: shape = (ns, repeat, nq)
Expand Down Expand Up @@ -143,7 +147,7 @@ def local_snapshot_states(
)
pauli_dm = backend.stack((X_dm, Y_dm, Z_dm), axis=0) # (3, 2, 2, 2)

def dm(p, s):
def dm(p: Tensor, s: Tensor) -> Tensor:
return backend.gather1d(backend.gather1d(pauli_dm, p), s)

v = backend.vmap(dm, vectorized_argnums=(0, 1))
Expand All @@ -160,7 +164,7 @@ def global_shadow_state(
snapshots: Tensor,
pauli_strings: Optional[Tensor] = None,
sub: Optional[Sequence[int]] = None,
):
) -> Tensor:
r"""To generate the global shadow state from local snapshot states or snapshots and pauli strings
:param snapshots: shape = (ns, repeat, nq, 2, 2) or (ns, repeat, nq)
Expand Down Expand Up @@ -189,7 +193,7 @@ def global_shadow_state(

nq = lss_states.shape[2]

def tensor_prod(dms):
def tensor_prod(dms: Tensor) -> Tensor:
res = backend.gather1d(dms, 0)
for i in range(1, nq):
res = backend.kron(res, backend.gather1d(dms, i))
Expand All @@ -209,7 +213,7 @@ def expection_ps_shadow(
z: Optional[Sequence[int]] = None,
ps: Optional[Sequence[int]] = None,
k: int = 1,
):
) -> List[Tensor]:
r"""To calculate the expectation value of an observable on shadow snapshot states
:param snapshots: shape = (ns, repeat, nq, 2, 2) or (ns, repeat, nq)
Expand Down Expand Up @@ -275,12 +279,12 @@ def expection_ps_shadow(
)
) # (4, 2, 2)

def trace_paulis_prod(dm, idx):
def trace_paulis_prod(dm: Tensor, idx: Tensor) -> Tensor:
return backend.real(backend.trace(backend.gather1d(paulis, idx) @ dm))

v = backend.vmap(trace_paulis_prod, vectorized_argnums=(0, 1)) # (nq,)

def prod(dm):
def prod(dm: Tensor) -> Tensor:
return backend.shape_prod(v(dm, ps))

vv = backend.vmap(prod, vectorized_argnums=0) # (ns,)
Expand All @@ -294,7 +298,7 @@ def entropy_shadow(
pauli_strings: Optional[Tensor] = None,
sub: Optional[Sequence[int]] = None,
alpha: int = 2,
):
) -> Tensor:
r"""To calculate the Renyi entropy of a subsystem from shadow state or shadow snapshot states
:param snapshots: shape = (ns, repeat, nq, 2, 2) or (ns, repeat, nq)
Expand Down Expand Up @@ -328,7 +332,7 @@ def global_shadow_state1(
snapshots: Tensor,
pauli_strings: Optional[Tensor] = None,
sub: Optional[Sequence[int]] = None,
):
) -> Tensor:
r"""To generate the global snapshots states from local snapshot states or snapshots and pauli strings
:param snapshots: shape = (ns, repeat, nq, 2, 2) or (ns, repeat, nq)
Expand Down Expand Up @@ -375,7 +379,7 @@ def global_shadow_state2(
snapshots: Tensor,
pauli_strings: Optional[Tensor] = None,
sub: Optional[Sequence[int]] = None,
):
) -> Tensor:
r"""To generate the global snapshots states from local snapshot states or snapshots and pauli strings
:param snapshots: shape = (ns, repeat, nq, 2, 2) or (ns, repeat, nq)
Expand Down Expand Up @@ -406,7 +410,7 @@ def global_shadow_state2(
old_indices = [f"{ABC[2 * i: 2 + 2 * i]}" for i in range(nq)]
new_indices = f"{ABC[0:2 * nq:2]}{ABC[1:2 * nq:2]}"

def tensor_prod(dms):
def tensor_prod(dms: Tensor) -> Tensor:
return backend.reshape(
backend.einsum(
f'{",".join(old_indices)}->{new_indices}', *dms, optimize=True
Expand All @@ -420,19 +424,21 @@ def tensor_prod(dms):
return backend.mean(gss_states, axis=(0, 1))


def slice_sub(entirety: Tensor, sub: Sequence[int]):
def slice_sub(entirety: Tensor, sub: Sequence[int]) -> Tensor:
r"""To slice off the subsystem
:param entirety: shape = (ns, repeat, nq, 2, 2)
:param entirety: shape = (ns, repeat, nq, 2, 2) or (ns, repeat, nq)
:type: Tensor
:param sub: qubit indices of subsystem
:type: Sequence[int]
:return subsystem: shape = (ns, repeat, nq_sub, 2, 2)
:rtype: Tensor
"""
if len(entirety.shape) < 3:
entirety = entirety[:, None, :]

def slc(x, idx):
def slc(x: Tensor, idx: Tensor) -> Tensor:
return backend.gather1d(x, idx)

v = backend.vmap(slc, vectorized_argnums=(1,))
Expand Down
19 changes: 9 additions & 10 deletions tests/test_shadows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
shadow_snapshots,
local_snapshot_states,
global_shadow_state,
global_shadow_state1,
global_shadow_state2,
entropy_shadow,
expection_ps_shadow,
global_shadow_state1,
global_shadow_state2,
slice_sub,
)


Expand Down Expand Up @@ -58,7 +59,7 @@ def classical_shadow(psi, pauli_strings, status):

@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
def test_state(backend):
nq, ns = 2, 6000
nq, ns = 2, 10000

c = tc.Circuit(nq)
c.H(0)
Expand All @@ -72,15 +73,13 @@ def test_state(backend):
lss_states = shadow_snapshots(c.state(), pauli_strings, status)
sdw_state = global_shadow_state(lss_states)

R = np.array(sdw_state - bell_state)
error = np.sqrt(np.trace(R.conj().T @ R))
assert error < 0.1
np.allclose(sdw_state, bell_state, atol=0.01)


# @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
# def test_expc(backend):
# import pennylane as qml
# nq, ns = 8, 100000
# nq, ns = 8, 200000
#
# c = tc.Circuit(nq)
# for i in range(nq):
Expand All @@ -100,13 +99,13 @@ def test_state(backend):
# )
#
# expc = np.median(expection_ps_shadow(snapshots, pauli_strings, ps=ps, k=9))
# ent = entropy_shadow(snapshots, pauli_strings, range(4), alpha=2)
# ent = entropy_shadow(slice_sub(snapshots, range(4)), slice_sub(pauli_strings, range(4)), alpha=2)
#
# shadow = qml.ClassicalShadow(np.asarray(snapshots[:, 0]), np.asarray(pauli_strings - 1)) # repeat == 1
# H = qml.PauliX(0) @ qml.PauliX(6) @ qml.PauliY(2)@ qml.PauliY(3) @ qml.PauliZ(5) @ qml.PauliZ(7)
# pl_expc = shadow.expval(H, k=9)
# pl_ent = shadow.entropy(range(4), alpha=2)
#
# print(np.isclose(expc, pl_expc), np.isclose(ent, pl_ent))
#
# assert np.isclose(expc, pl_expc)
# assert np.isclose(ent, pl_ent)
# assert np.abs(expc - exact) < 0.1

0 comments on commit f2185bd

Please sign in to comment.