Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make And a leaf bloq #1513

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions qualtran/bloqs/arithmetic/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ def test_permutation_cycle_unitary_and_call_graph():
)

cv = sympy.Symbol('cv')
_, sigma = bloq.call_graph(
generalizer=[ignore_split_join, generalize_cvs], keep=lambda b: isinstance(b, And)
)
_, sigma = bloq.call_graph(generalizer=[ignore_split_join, generalize_cvs])
assert sigma == {
CNOT(): 8,
And(cv1=cv, cv2=cv): 4,
Expand All @@ -106,7 +104,7 @@ def test_permutation_cycle_symbolic_call_graph():
bloq = _permutation_cycle_symb()
logN, L = ceil(log2(bloq.N)), slen(bloq.cycle)

_, sigma = bloq.call_graph(keep=lambda b: isinstance(b, And))
_, sigma = bloq.call_graph()
assert sigma == {
And(): (L + 1) * (logN - 1),
And().adjoint(): (L + 1) * (logN - 1),
Expand All @@ -133,7 +131,7 @@ def test_permutation_unitary_and_call_graph():
),
)

_, sigma = bloq.call_graph(generalizer=ignore_split_join, keep=lambda b: isinstance(b, And))
_, sigma = bloq.call_graph(generalizer=ignore_split_join)
assert sigma == {
CNOT(): 17,
And(): 56 // 4,
Expand All @@ -160,7 +158,7 @@ def test_permutation_symbolic_call_graph():
logN = ceil(log2(N))
bloq = _permutation_symb()

_, sigma = bloq.call_graph(keep=lambda b: isinstance(b, And))
_, sigma = bloq.call_graph()
assert sigma == {
And().adjoint(): (N + 1) * (logN - 1),
And(): (N + 1) * (logN - 1),
Expand Down
7 changes: 3 additions & 4 deletions qualtran/bloqs/chemistry/df/double_factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import CSwap, Hadamard, Toffoli
from qualtran.bloqs.basic_gates import CSwap, Hadamard
from qualtran.bloqs.block_encoding import BlockEncoding
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.bloqs.chemistry.black_boxes import ApplyControlledZs
from qualtran.bloqs.chemistry.df.prepare import (
InnerPrepareDoubleFactorization,
Expand Down Expand Up @@ -280,10 +279,10 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
in_prep_dag: 1, # in_prep_l^dag
rot: 1, # rotate into system basis listing 4 pg 54
# apply CCZ first then CCCZ, the cost is 1 + 2 Toffolis (step 4e, and 7)
Toffoli(): 1,
ApplyControlledZs(cvs=(1, 1), bitsize=5): 1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this incorrect before? I see the change makes it match the comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous test was only checking the sigma for the call graph. I updated it to check for equivalent bloq counts, which flagged this inconsistency.

rot_dag: 1, # Undo rotations
CSwap(self.num_spin_orb // 2): 2, # Swaps for spins
ArbitraryClifford(n=1): 1, # 2 Hadamards for spin superposition
Hadamard(): 2, # 2 Hadamards for spin superposition
}

def __str__(self) -> str:
Expand Down
9 changes: 3 additions & 6 deletions qualtran/bloqs/chemistry/df/double_factorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openfermion.resource_estimates.df.compute_cost_df import compute_cost
from openfermion.resource_estimates.utils import power_two

from qualtran.bloqs.basic_gates import TGate
import qualtran.testing as qlt_testing
from qualtran.bloqs.chemistry.df.double_factorization import (
_df_block_encoding,
_df_one_body,
Expand All @@ -26,7 +26,6 @@
PrepareUniformSuperposition,
)
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.testing import execute_notebook


def test_df_block_encoding(bloq_autotester):
Expand All @@ -51,9 +50,7 @@ def test_compare_cost_one_body_decomp():
num_bits_rot_aa=7,
num_bits_rot=num_bits_rot,
)
costs = bloq.call_graph()[1]
cbloq_costs = bloq.decompose_bloq().call_graph()[1]
assert costs[TGate()] == cbloq_costs[TGate()]
qlt_testing.assert_equivalent_bloq_counts(bloq)


def test_compare_cost_to_openfermion():
Expand Down Expand Up @@ -106,4 +103,4 @@ def test_compare_cost_to_openfermion():

@pytest.mark.notebook
def test_notebook():
execute_notebook("double_factorization")
qlt_testing.execute_notebook("double_factorization")
9 changes: 3 additions & 6 deletions qualtran/bloqs/chemistry/sf/single_factorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openfermion.resource_estimates.sf.compute_cost_sf import compute_cost
from openfermion.resource_estimates.utils import power_two, QI, QI2, QR2

from qualtran.bloqs.basic_gates import TGate
import qualtran.testing as qlt_testing
from qualtran.bloqs.chemistry.sf.single_factorization import (
_sf_block_encoding,
_sf_one_body,
Expand All @@ -26,7 +26,6 @@
PrepareUniformSuperposition,
)
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.testing import execute_notebook


def test_sf_block_encoding(bloq_autotester):
Expand All @@ -48,9 +47,7 @@ def test_compare_cost_one_body_decomp():
num_bits_state_prep=num_bits_state_prep,
num_bits_rot_aa=num_bits_rot_aa,
)
costs = bloq.call_graph()[1]
cbloq_costs = bloq.decompose_bloq().call_graph()[1]
assert costs[TGate()] == cbloq_costs[TGate()]
qlt_testing.assert_equivalent_bloq_counts(bloq)


def test_compare_cost_to_openfermion():
Expand Down Expand Up @@ -126,4 +123,4 @@ def test_compare_cost_to_openfermion():

@pytest.mark.notebook
def test_notebook():
execute_notebook("single_factorization")
qlt_testing.execute_notebook("single_factorization")
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import sympy
from numpy.typing import NDArray

from qualtran.bloqs.basic_gates import TGate, TwoBitCSwap
from qualtran.bloqs.for_testing.matrix_gate import MatrixGate
from qualtran.bloqs.for_testing.random_select_and_prepare import random_qubitization_walk_operator
from qualtran.bloqs.hamiltonian_simulation.hamiltonian_simulation_by_gqsp import (
Expand All @@ -34,7 +33,7 @@
)
from qualtran.bloqs.qubitization.qubitization_walk_operator import QubitizationWalkOperator
from qualtran.cirq_interop import BloqAsCirqGate
from qualtran.resource_counting import big_O, BloqCount, get_cost_value, QECGatesCost, QubitCount
from qualtran.resource_counting import big_O, get_cost_value, QECGatesCost, QubitCount
from qualtran.symbolics import Shaped


Expand Down Expand Up @@ -109,8 +108,10 @@ def test_hamiltonian_simulation_by_gqsp_t_complexity():
hubbard_time_evolution_by_gqsp = _hubbard_time_evolution_by_gqsp.make()
t_comp = hubbard_time_evolution_by_gqsp.t_complexity()

counts = get_cost_value(hubbard_time_evolution_by_gqsp, BloqCount.for_gateset('t+tof+cswap'))
assert t_comp.t == counts[TwoBitCSwap()] * 7 + counts[TGate()]
counts = get_cost_value(hubbard_time_evolution_by_gqsp, QECGatesCost())
t_comp_from_qec = counts.to_legacy_t_complexity()
assert t_comp.t == t_comp_from_qec.t
assert t_comp.rotations == t_comp_from_qec.rotations


def test_symbolic_t_cost():
Expand Down
49 changes: 5 additions & 44 deletions qualtran/bloqs/mcmt/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,17 @@
Side,
Signature,
)
from qualtran.bloqs.basic_gates import TGate, XGate
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.bloqs.basic_gates import XGate
from qualtran.cirq_interop import decompose_from_cirq_style_method
from qualtran.drawing import Circle, directional_text_box, Text, WireSymbol
from qualtran.resource_counting import (
big_O,
BloqCountDictT,
MutableBloqCountDictT,
SympySymbolAllocator,
)
from qualtran.resource_counting.generalizers import (
cirq_to_bloqs,
generalize_cvs,
generalize_rotation_angle,
ignore_alloc_free,
ignore_cliffords,
)
from qualtran.resource_counting import BloqCountDictT, MutableBloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import generalize_cvs, ignore_cliffords
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt

if TYPE_CHECKING:
import quimb.tensor as qtn

# TODO: https://github.com/quantumlib/Qualtran/issues/1346
FLAG_AND_AS_LEAF = False


@frozen
class And(GateWithRegisters):
Expand Down Expand Up @@ -108,22 +93,7 @@ def adjoint(self) -> 'And':
return attrs.evolve(self, uncompute=not self.uncompute)

def decompose_bloq(self) -> 'CompositeBloq':
if FLAG_AND_AS_LEAF:
raise DecomposeTypeError(f"{self} is atomic.")
return decompose_from_cirq_style_method(self)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if FLAG_AND_AS_LEAF:
raise DecomposeTypeError(f"{self} is atomic.")

if isinstance(self.cv1, sympy.Expr) or isinstance(self.cv2, sympy.Expr):
pre_post_cliffords: Union[sympy.Order, int] = big_O(1)
else:
pre_post_cliffords = 2 - self.cv1 - self.cv2
if self.uncompute:
return {ArbitraryClifford(n=2): 4 + 2 * pre_post_cliffords}

return {ArbitraryClifford(n=2): 9 + 2 * pre_post_cliffords, TGate(): 4}
raise DecomposeTypeError(f"{self} is atomic.")

def on_classical_vals(
self, *, ctrl: NDArray[np.uint8], target: Optional[int] = None
Expand Down Expand Up @@ -243,13 +213,6 @@ def to_clifford_t_circuit(self) -> 'cirq.FrozenCircuit':
circuit += pre_post_ops
return circuit.freeze()

def __pow__(self, power: int) -> 'And':
if power == 1:
return self
if power == -1:
return self.adjoint()
return NotImplemented # pragma: no cover

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
controls = ["(0)", "@"]
target = "And†" if self.uncompute else "And"
Expand All @@ -263,9 +226,7 @@ def _has_unitary_(self) -> bool:
return not self.uncompute


@bloq_example(
generalizer=[cirq_to_bloqs, ignore_cliffords, ignore_alloc_free, generalize_rotation_angle]
)
@bloq_example()
def _and_bloq() -> And:
and_bloq = And()
return and_bloq
Expand Down
20 changes: 5 additions & 15 deletions qualtran/bloqs/mcmt/specialized_ctrl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@
from qualtran.resource_counting import CostKey, GateCounts, get_cost_value, QECGatesCost


def _keep_and(b):
# TODO remove this after https://github.com/quantumlib/Qualtran/issues/1346 is resolved.
return isinstance(b, And)


@attrs.frozen
class AtomWithSpecializedControl(Bloq):
cv: Optional[int] = None
Expand Down Expand Up @@ -169,30 +164,25 @@ def test_bloq_with_controlled_bloq():
assert TestAtom('g').controlled() == CTestAtom('g')

ctrl_bloq = CTestAtom('g').controlled()
_, sigma = ctrl_bloq.call_graph(keep=_keep_and)
_, sigma = ctrl_bloq.call_graph()
assert sigma == {And(): 1, CTestAtom('g'): 1, And().adjoint(): 1}

ctrl_bloq = CTestAtom('n').controlled(CtrlSpec(cvs=0))
_, sigma = ctrl_bloq.call_graph(keep=_keep_and)
_, sigma = ctrl_bloq.call_graph()
assert sigma == {And(0, 1): 1, CTestAtom('n'): 1, And(0, 1).adjoint(): 1}

ctrl_bloq = TestAtom('nn').controlled(CtrlSpec(cvs=[0, 0]))
_, sigma = ctrl_bloq.call_graph(keep=_keep_and)
_, sigma = ctrl_bloq.call_graph()
assert sigma == {And(0, 0): 1, CTestAtom('nn'): 1, And(0, 0).adjoint(): 1}


def test_ctrl_adjoint():
assert TestAtom('a').adjoint().controlled() == CTestAtom('a').adjoint()

_, sigma = (
TestAtom('g')
.adjoint()
.controlled(ctrl_spec=CtrlSpec(cvs=[1, 1]))
.call_graph(keep=_keep_and)
)
_, sigma = TestAtom('g').adjoint().controlled(ctrl_spec=CtrlSpec(cvs=[1, 1])).call_graph()
assert sigma == {And(): 1, And().adjoint(): 1, CTestAtom('g').adjoint(): 1}

_, sigma = CTestAtom('c').adjoint().controlled().call_graph(keep=_keep_and)
_, sigma = CTestAtom('c').adjoint().controlled().call_graph()
assert sigma == {And(): 1, And().adjoint(): 1, CTestAtom('c').adjoint(): 1}

for cv in [0, 1]:
Expand Down
10 changes: 5 additions & 5 deletions qualtran/bloqs/multiplexers/apply_lth_bloq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import pytest

from qualtran import BloqBuilder, BQUInt, Controlled, CtrlSpec, QBit, Register, Signature, Soquet
from qualtran import BloqBuilder, BQUInt, QBit, Register, Signature, Soquet
from qualtran.bloqs.basic_gates import (
CHadamard,
CNOT,
Expand All @@ -34,7 +34,7 @@
ZeroState,
ZGate,
)
from qualtran.bloqs.bookkeeping.arbitrary_clifford import ArbitraryClifford
from qualtran.bloqs.mcmt import And
from qualtran.bloqs.multiplexers.apply_lth_bloq import _apply_lth_bloq, ApplyLthBloq
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.testing import assert_valid_bloq_decomposition
Expand Down Expand Up @@ -64,11 +64,11 @@ def test_call_graph():
_, sigma = _apply_lth_bloq().call_graph(generalizer=ignore_split_join)
assert sigma == {
CHadamard(): 1,
Controlled(TGate(), CtrlSpec()): 1,
TGate().controlled(): 1,
CZ(): 1,
CNOT(): 4,
TGate(): 12,
ArbitraryClifford(2): 45,
And(1, 0): 3,
And().adjoint(): 3,
}


Expand Down
6 changes: 1 addition & 5 deletions qualtran/drawing/bloq_counts_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def test_format_counts_sigma():
ret
== """\
#### Counts totals:
- `ArbitraryClifford(n=2)`: 45
- `T`: 20"""
- `And`: 5"""
)


Expand All @@ -43,9 +42,6 @@ def test_format_counts_graph_markdown():
== """\
- `MultiAnd(n=6)`
- `And`: $\\displaystyle 5$
- `And`
- `ArbitraryClifford(n=2)`: $\\displaystyle 9$
- `T`: $\\displaystyle 4$
"""
)

Expand Down
4 changes: 4 additions & 0 deletions qualtran/resource_counting/t_counts_from_sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ def t_counts_from_sigma(sigma: Mapping['Bloq', SymbolicInt]) -> SymbolicInt:
import cirq

from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap
from qualtran.bloqs.mcmt import And
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting.classify_bloqs import bloq_is_rotation

ret = sigma.get(TGate(), 0) + sigma.get(TGate().adjoint(), 0)
ret += sigma.get(Toffoli(), 0) * 4
ret += sigma.get(TwoBitCSwap(), 0) * 7
for bloq, counts in sigma.items():
if isinstance(bloq, And) and not bloq.uncompute:
ret += counts * 4

if bloq_is_rotation(bloq) and not cirq.has_stabilizer_effect(bloq):
if isinstance(bloq, Controlled):
# TODO native controlled rotation bloqs missing (CRz, CRy etc.)
Expand Down
Loading