Skip to content

ReflectionUsingPrepare refactor #1351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 83 additions & 10 deletions qualtran/bloqs/mcmt/multi_control_pauli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from functools import cached_property
from typing import Dict, Set, Tuple, TYPE_CHECKING, Union

Expand All @@ -28,15 +28,17 @@
CtrlSpec,
DecomposeTypeError,
GateWithRegisters,
QAny,
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import XGate
from qualtran.bloqs.mcmt.and_bloq import _to_tuple_or_has_length, is_symbolic
from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd
from qualtran.symbolics import HasLength, SymbolicInt
from qualtran.symbolics import HasLength, slen, SymbolicInt

if TYPE_CHECKING:
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -184,26 +186,97 @@ def _ccpauli_symb() -> MultiControlPauli:


@frozen
class MultiControlX(MultiControlPauli):
class _MultiControlPauli(Bloq):
"""Abstract base class for a generalized implementation of Multi-Control Pauli"""

cvs: CtrlSpec

@property
@abc.abstractmethod
def _target_gate(self) -> cirq.Pauli:
...

@cached_property
def signature(self) -> 'Signature':
return Signature([*self.ctrl_regs, Register('target', QBit())])

@cached_property
def ctrl_regs(self) -> Tuple[Register, ...]:
ctrl_regs = []
for i, (dtype, shape) in enumerate(self.cvs.activation_function_dtypes()):
if is_symbolic(dtype.num_qubits) or dtype.num_qubits > 0:
ctrl_regs.append(Register(f'ctrl{i}_', dtype=dtype, shape=shape))
return tuple(ctrl_regs)

@cached_property
def flat_cvs(self) -> Union[HasLength, Tuple[int, ...]]:
return tuple(
b
for cvs, qdtype in zip(self.cvs.cvs, self.cvs.qdtypes)
for cv in tuple(cvs.reshape(-1))
for b in qdtype.to_bits(cv)
)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
flat_ctrls = []
for reg in self.ctrl_regs:
if reg.shape:
for soq in soqs[reg.name].reshape(-1):
flat_ctrls += bb.split(soq).tolist() if reg.bitsize > 1 else [soq]
else:
soq = soqs[reg.name]
assert isinstance(soq, Soquet)
flat_ctrls += bb.split(soq).tolist() if reg.bitsize > 1 else [soq]
print(self.flat_cvs)
print(flat_ctrls)
flat_ctrls, target = bb.add(
MultiControlPauli(self.flat_cvs, self._target_gate),
controls=flat_ctrls,
target=soqs['target'],
)
ctrls = {}
st = 0
for reg in self.ctrl_regs:
if reg.shape:
curr_soqs = np.empty(reg.shape, dtype=object)
for idx in reg.all_idxs():
if reg.bitsize > 1:
curr_soqs[idx] = bb.join(flat_ctrls[st : st + reg.bitsize], dtype=reg.dtype)
else:
curr_soqs[idx] = flat_ctrls[st]
st += reg.bitsize
ctrls[reg.name] = curr_soqs
else:
if reg.bitsize > 1:
ctrls[reg.name] = bb.join(flat_ctrls[st : st + reg.bitsize], dtype=reg.dtype)
else:
ctrls[reg.name] = flat_ctrls[st]
st += reg.bitsize
return ctrls | {'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(MultiControlPauli(self.flat_cvs, self._target_gate), 1)}


@frozen
class MultiControlX(_MultiControlPauli):
r"""Implements multi-control, single-target X gate.

See :class:`MultiControlPauli` for implementation and costs.
"""
target_gate: cirq.Pauli = field(init=False)

@target_gate.default
def _X(self):
@property
def _target_gate(self) -> cirq.Pauli:
return cirq.X


@frozen
class MultiControlZ(MultiControlPauli):
class MultiControlZ(_MultiControlPauli):
r"""Implements multi-control, single-target Z gate.

See :class:`MultiControlPauli` for implementation and costs.
"""
target_gate: cirq.Pauli = field(init=False)

@target_gate.default
def _Z(self):
@property
def _target_gate(self) -> cirq.Pauli:
return cirq.Z
124 changes: 79 additions & 45 deletions qualtran/bloqs/reflections/reflection_using_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,25 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, Iterator, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

import attrs
import cirq
import numpy as np
from numpy.typing import NDArray

from qualtran import Bloq, bloq_example, BloqDocSpec, CtrlSpec, QBit, Register, Signature
from qualtran import Bloq, bloq_example, BloqDocSpec, CtrlSpec, QBit, Register, Side, Signature
from qualtran._infra.gate_with_registers import GateWithRegisters, merge_qubits, total_bits
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
from qualtran.bloqs.basic_gates.global_phase import GlobalPhase
from qualtran.bloqs.basic_gates.x_basis import XGate
from qualtran.bloqs.basic_gates import GlobalPhase, XGate, ZPowGate
from qualtran.bloqs.mcmt import MultiControlZ
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
from qualtran.drawing import Circle, ModPlus, TextBox
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt
from qualtran.symbolics import HasLength, is_symbolic, pi, sarg, SymbolicFloat, SymbolicInt

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT
from qualtran.bloqs.block_encoding.lcu_block_encoding import BlackBoxPrepare
from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -78,7 +79,7 @@ class ReflectionUsingPrepare(GateWithRegisters, SpecializedSingleQubitControlled
prepare_gate: Union['PrepareOracle', 'BlackBoxPrepare']
control_val: Optional[int] = None
global_phase: complex = 1
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
Expand All @@ -88,9 +89,13 @@ def control_registers(self) -> Tuple[Register, ...]:
def selection_registers(self) -> Tuple[Register, ...]:
return self.prepare_gate.selection_registers

@cached_property
def junk_registers(self) -> Tuple[Register, ...]:
return self.prepare_gate.junk_registers

@cached_property
def signature(self) -> Signature:
return Signature([*self.control_registers, *self.selection_registers])
return Signature([*self.control_registers, *self.selection_registers, *self.junk_registers])

@classmethod
def reflection_around_zero(
Expand All @@ -117,58 +122,79 @@ def reflection_around_zero(
prepare_gate, control_val=control_val, global_phase=global_phase, eps=eps
)

def decompose_from_registers(
self,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
qm = context.qubit_manager
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
# 0. Allocate new ancillas, if needed.
phase_target = qm.qalloc(1)[0] if self.control_val is None else quregs.pop('control')[0]
state_prep_ancilla = {
reg.name: np.array(qm.qalloc(reg.total_bits())).reshape(reg.shape + (reg.bitsize,))
for reg in self.prepare_gate.junk_registers
}
state_prep_selection_regs = quregs
prepare_op = self.prepare_gate.on_registers(
**state_prep_selection_regs, **state_prep_ancilla
phase_target = (
bb.allocate(dtype=QBit()) if self.control_val is None else soqs.pop('control')
)
# prep_soqs =
# state_prep_junk = {
# reg.name: soqs.pop(reg.name)
# for reg in self.prepare_gate.junk_registers
# if reg.side & Side.LEFT
# }
# state_prep_sel = soqs
# state_prep_ancilla = {
# reg.name: np.array(bb.qalloc(reg.total_bits())).reshape(reg.shape + (reg.bitsize,))
# for reg in self.prepare_gate.junk_registers
# }
# state_prep_selection_regs = quregs
# prepare_op = self.prepare_gate.on_registers(
# **state_prep_selection_regs, **state_prep_ancilla
# )
# 1. PREPARE†
yield cirq.inverse(prepare_op)
soqs = bb.add_d(self.prepare_gate.adjoint(), **soqs)
# yield cirq.inverse(prepare_op)
# 2. MultiControlled Z, controlled on |000..00> state.
phase_control = np.array(
merge_qubits(self.selection_registers, **state_prep_selection_regs)
)
yield cirq.X(phase_target) if not self.control_val else []
yield MultiControlZ([0] * len(phase_control)).on_registers(
controls=phase_control.reshape(phase_control.shape + (1,)), target=phase_target
)
# phase_control = np.array(
# merge_qubits(self.selection_registers, **state_prep_selection_regs)
# )
if self.control_val is None:
phase_target = bb.add(XGate(), q=phase_target)
multi_ctrl_soqs = {
ctrl_reg.name: soqs.pop(sel_reg.name)
for ctrl_reg, sel_reg in zip(self._multi_ctrl_z.ctrl_regs, self.selection_registers)
}
multi_ctrl_soqs = bb.add_d(self._multi_ctrl_z, **multi_ctrl_soqs, target=phase_target)
for ctrl_reg, sel_reg in zip(self._multi_ctrl_z.ctrl_regs, self.selection_registers):
soqs[sel_reg.name] = multi_ctrl_soqs.pop(ctrl_reg.name)
phase_target = multi_ctrl_soqs.pop('target')
assert not multi_ctrl_soqs
if self.global_phase != 1:
exponent = sarg(self.global_phase) / pi(self.global_phase)
if self.control_val is None:
yield cirq.global_phase_operation(self.global_phase, atol=self.eps)
bb.add(GlobalPhase(exponent=exponent, eps=self.eps))
else:
yield cirq.Z(phase_target) ** (np.angle(self.global_phase) / np.pi)
yield cirq.X(phase_target) if not self.control_val else []
phase_target = bb.add(ZPowGate(exponent=exponent, eps=self.eps), q=phase_target)
# 3. PREPARE
yield prepare_op

# 4. Deallocate ancilla.
qm.qfree([q for anc in state_prep_ancilla.values() for q in anc.flatten()])
soqs = bb.add_d(self.prepare_gate, **soqs)
# 4. Deallocate phase_target.
if self.control_val is None:
qm.qfree([phase_target])

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers)
wire_symbols += ['R_L'] * total_bits(self.selection_registers)
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
phase_target = bb.add(XGate(), q=phase_target)
bb.free(phase_target)
else:
soqs['control'] = phase_target
return soqs

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg.name == 'control':
return Circle(filled=True)
if reg.name in {reg.name for reg in self.selection_registers}:
return f'R_L[{",".join(str(i) for i in idx)}]'
return self.prepare_gate.wire_symbol(reg, idx)

# def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
# wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers)
# wire_symbols += ['R_L'] * total_bits(self.selection_registers)
# return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
n_phase_control = sum(reg.total_bits() for reg in self.selection_registers)
cvs = HasLength(n_phase_control) if is_symbolic(n_phase_control) else [0] * n_phase_control
# n_phase_control = sum(reg.total_bits() for reg in self.selection_registers)
# cvs = HasLength(n_phase_control) if is_symbolic(n_phase_control) else [0] * n_phase_control
costs: Set['BloqCountT'] = {
(self.prepare_gate, 1),
(self.prepare_gate.adjoint(), 1),
(MultiControlZ(cvs), 1),
(self._multi_ctrl_z, 1),
}
if self.control_val is None:
costs.add((XGate(), 2))
Expand All @@ -179,6 +205,14 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
costs.add((phase_op, 1))
return costs

@cached_property
def _multi_ctrl_z(self) -> MultiControlZ:
ctrl_soqs, dtypes, cvs = {}, [], []
for reg in self.selection_registers:
dtypes.append(reg.dtype)
cvs.append(np.zeros(reg.shape, dtype=int))
return MultiControlZ(CtrlSpec(tuple(dtypes), tuple(cvs)))


@bloq_example(generalizer=ignore_split_join)
def _refl_using_prep() -> ReflectionUsingPrepare:
Expand Down
Loading
Loading