Skip to content

Commit

Permalink
Initial attempt to update ReflectionUsingPrepare with lots of failing…
Browse files Browse the repository at this point in the history
… tests
  • Loading branch information
tanujkhattar committed Aug 27, 2024
1 parent 862e7e7 commit dcdb655
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 143 deletions.
93 changes: 83 additions & 10 deletions qualtran/bloqs/mcmt/multi_control_pauli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from functools import cached_property
from typing import Dict, Set, Tuple, TYPE_CHECKING, Union

Expand All @@ -28,15 +28,17 @@
CtrlSpec,
DecomposeTypeError,
GateWithRegisters,
QAny,
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import XGate
from qualtran.bloqs.mcmt.and_bloq import _to_tuple_or_has_length, is_symbolic
from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd
from qualtran.symbolics import HasLength, SymbolicInt
from qualtran.symbolics import HasLength, slen, SymbolicInt

if TYPE_CHECKING:
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -184,26 +186,97 @@ def _ccpauli_symb() -> MultiControlPauli:


@frozen
class MultiControlX(MultiControlPauli):
class _MultiControlPauli(Bloq):
"""Abstract base class for a generalized implementation of Multi-Control Pauli"""

cvs: CtrlSpec

@property
@abc.abstractmethod
def _target_gate(self) -> cirq.Pauli:
...

@cached_property
def signature(self) -> 'Signature':
return Signature([*self.ctrl_regs, Register('target', QBit())])

@cached_property
def ctrl_regs(self) -> Tuple[Register, ...]:
ctrl_regs = []
for i, (dtype, shape) in enumerate(self.cvs.activation_function_dtypes()):
if is_symbolic(dtype.num_qubits) or dtype.num_qubits > 0:
ctrl_regs.append(Register(f'ctrl{i}_', dtype=dtype, shape=shape))
return tuple(ctrl_regs)

@cached_property
def flat_cvs(self) -> Union[HasLength, Tuple[int, ...]]:
return tuple(
b
for cvs, qdtype in zip(self.cvs.cvs, self.cvs.qdtypes)
for cv in tuple(cvs.reshape(-1))
for b in qdtype.to_bits(cv)
)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
flat_ctrls = []
for reg in self.ctrl_regs:
if reg.shape:
for soq in soqs[reg.name].reshape(-1):
flat_ctrls += bb.split(soq).tolist() if reg.bitsize > 1 else [soq]
else:
soq = soqs[reg.name]
assert isinstance(soq, Soquet)
flat_ctrls += bb.split(soq).tolist() if reg.bitsize > 1 else [soq]
print(self.flat_cvs)
print(flat_ctrls)
flat_ctrls, target = bb.add(
MultiControlPauli(self.flat_cvs, self._target_gate),
controls=flat_ctrls,
target=soqs['target'],
)
ctrls = {}
st = 0
for reg in self.ctrl_regs:
if reg.shape:
curr_soqs = np.empty(reg.shape, dtype=object)
for idx in reg.all_idxs():
if reg.bitsize > 1:
curr_soqs[idx] = bb.join(flat_ctrls[st : st + reg.bitsize], dtype=reg.dtype)
else:
curr_soqs[idx] = flat_ctrls[st]
st += reg.bitsize
ctrls[reg.name] = curr_soqs
else:
if reg.bitsize > 1:
ctrls[reg.name] = bb.join(flat_ctrls[st : st + reg.bitsize], dtype=reg.dtype)
else:
ctrls[reg.name] = flat_ctrls[st]
st += reg.bitsize
return ctrls | {'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(MultiControlPauli(self.flat_cvs, self._target_gate), 1)}


@frozen
class MultiControlX(_MultiControlPauli):
r"""Implements multi-control, single-target X gate.
See :class:`MultiControlPauli` for implementation and costs.
"""
target_gate: cirq.Pauli = field(init=False)

@target_gate.default
def _X(self):
@property
def _target_gate(self) -> cirq.Pauli:
return cirq.X


@frozen
class MultiControlZ(MultiControlPauli):
class MultiControlZ(_MultiControlPauli):
r"""Implements multi-control, single-target Z gate.
See :class:`MultiControlPauli` for implementation and costs.
"""
target_gate: cirq.Pauli = field(init=False)

@target_gate.default
def _Z(self):
@property
def _target_gate(self) -> cirq.Pauli:
return cirq.Z
124 changes: 79 additions & 45 deletions qualtran/bloqs/reflections/reflection_using_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,25 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, Iterator, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

import attrs
import cirq
import numpy as np
from numpy.typing import NDArray

from qualtran import Bloq, bloq_example, BloqDocSpec, CtrlSpec, QBit, Register, Signature
from qualtran import Bloq, bloq_example, BloqDocSpec, CtrlSpec, QBit, Register, Side, Signature
from qualtran._infra.gate_with_registers import GateWithRegisters, merge_qubits, total_bits
from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension
from qualtran.bloqs.basic_gates.global_phase import GlobalPhase
from qualtran.bloqs.basic_gates.x_basis import XGate
from qualtran.bloqs.basic_gates import GlobalPhase, XGate, ZPowGate
from qualtran.bloqs.mcmt import MultiControlZ
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
from qualtran.drawing import Circle, ModPlus, TextBox
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt
from qualtran.symbolics import HasLength, is_symbolic, pi, sarg, SymbolicFloat, SymbolicInt

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT
from qualtran.bloqs.block_encoding.lcu_block_encoding import BlackBoxPrepare
from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -78,7 +79,7 @@ class ReflectionUsingPrepare(GateWithRegisters, SpecializedSingleQubitControlled
prepare_gate: Union['PrepareOracle', 'BlackBoxPrepare']
control_val: Optional[int] = None
global_phase: complex = 1
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

@cached_property
def control_registers(self) -> Tuple[Register, ...]:
Expand All @@ -88,9 +89,13 @@ def control_registers(self) -> Tuple[Register, ...]:
def selection_registers(self) -> Tuple[Register, ...]:
return self.prepare_gate.selection_registers

@cached_property
def junk_registers(self) -> Tuple[Register, ...]:
return self.prepare_gate.junk_registers

@cached_property
def signature(self) -> Signature:
return Signature([*self.control_registers, *self.selection_registers])
return Signature([*self.control_registers, *self.selection_registers, *self.junk_registers])

@classmethod
def reflection_around_zero(
Expand All @@ -117,58 +122,79 @@ def reflection_around_zero(
prepare_gate, control_val=control_val, global_phase=global_phase, eps=eps
)

def decompose_from_registers(
self,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
qm = context.qubit_manager
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
# 0. Allocate new ancillas, if needed.
phase_target = qm.qalloc(1)[0] if self.control_val is None else quregs.pop('control')[0]
state_prep_ancilla = {
reg.name: np.array(qm.qalloc(reg.total_bits())).reshape(reg.shape + (reg.bitsize,))
for reg in self.prepare_gate.junk_registers
}
state_prep_selection_regs = quregs
prepare_op = self.prepare_gate.on_registers(
**state_prep_selection_regs, **state_prep_ancilla
phase_target = (
bb.allocate(dtype=QBit()) if self.control_val is None else soqs.pop('control')
)
# prep_soqs =
# state_prep_junk = {
# reg.name: soqs.pop(reg.name)
# for reg in self.prepare_gate.junk_registers
# if reg.side & Side.LEFT
# }
# state_prep_sel = soqs
# state_prep_ancilla = {
# reg.name: np.array(bb.qalloc(reg.total_bits())).reshape(reg.shape + (reg.bitsize,))
# for reg in self.prepare_gate.junk_registers
# }
# state_prep_selection_regs = quregs
# prepare_op = self.prepare_gate.on_registers(
# **state_prep_selection_regs, **state_prep_ancilla
# )
# 1. PREPARE†
yield cirq.inverse(prepare_op)
soqs = bb.add_d(self.prepare_gate.adjoint(), **soqs)
# yield cirq.inverse(prepare_op)
# 2. MultiControlled Z, controlled on |000..00> state.
phase_control = np.array(
merge_qubits(self.selection_registers, **state_prep_selection_regs)
)
yield cirq.X(phase_target) if not self.control_val else []
yield MultiControlZ([0] * len(phase_control)).on_registers(
controls=phase_control.reshape(phase_control.shape + (1,)), target=phase_target
)
# phase_control = np.array(
# merge_qubits(self.selection_registers, **state_prep_selection_regs)
# )
if self.control_val is None:
phase_target = bb.add(XGate(), q=phase_target)
multi_ctrl_soqs = {
ctrl_reg.name: soqs.pop(sel_reg.name)
for ctrl_reg, sel_reg in zip(self._multi_ctrl_z.ctrl_regs, self.selection_registers)
}
multi_ctrl_soqs = bb.add_d(self._multi_ctrl_z, **multi_ctrl_soqs, target=phase_target)
for ctrl_reg, sel_reg in zip(self._multi_ctrl_z.ctrl_regs, self.selection_registers):
soqs[sel_reg.name] = multi_ctrl_soqs.pop(ctrl_reg.name)
phase_target = multi_ctrl_soqs.pop('target')
assert not multi_ctrl_soqs
if self.global_phase != 1:
exponent = sarg(self.global_phase) / pi(self.global_phase)
if self.control_val is None:
yield cirq.global_phase_operation(self.global_phase, atol=self.eps)
bb.add(GlobalPhase(exponent=exponent, eps=self.eps))
else:
yield cirq.Z(phase_target) ** (np.angle(self.global_phase) / np.pi)
yield cirq.X(phase_target) if not self.control_val else []
phase_target = bb.add(ZPowGate(exponent=exponent, eps=self.eps), q=phase_target)
# 3. PREPARE
yield prepare_op

# 4. Deallocate ancilla.
qm.qfree([q for anc in state_prep_ancilla.values() for q in anc.flatten()])
soqs = bb.add_d(self.prepare_gate, **soqs)
# 4. Deallocate phase_target.
if self.control_val is None:
qm.qfree([phase_target])

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers)
wire_symbols += ['R_L'] * total_bits(self.selection_registers)
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
phase_target = bb.add(XGate(), q=phase_target)
bb.free(phase_target)
else:
soqs['control'] = phase_target
return soqs

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg.name == 'control':
return Circle(filled=True)
if reg.name in {reg.name for reg in self.selection_registers}:
return f'R_L[{",".join(str(i) for i in idx)}]'
return self.prepare_gate.wire_symbol(reg, idx)

# def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
# wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers)
# wire_symbols += ['R_L'] * total_bits(self.selection_registers)
# return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
n_phase_control = sum(reg.total_bits() for reg in self.selection_registers)
cvs = HasLength(n_phase_control) if is_symbolic(n_phase_control) else [0] * n_phase_control
# n_phase_control = sum(reg.total_bits() for reg in self.selection_registers)
# cvs = HasLength(n_phase_control) if is_symbolic(n_phase_control) else [0] * n_phase_control
costs: Set['BloqCountT'] = {
(self.prepare_gate, 1),
(self.prepare_gate.adjoint(), 1),
(MultiControlZ(cvs), 1),
(self._multi_ctrl_z, 1),
}
if self.control_val is None:
costs.add((XGate(), 2))
Expand All @@ -179,6 +205,14 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
costs.add((phase_op, 1))
return costs

@cached_property
def _multi_ctrl_z(self) -> MultiControlZ:
ctrl_soqs, dtypes, cvs = {}, [], []
for reg in self.selection_registers:
dtypes.append(reg.dtype)
cvs.append(np.zeros(reg.shape, dtype=int))
return MultiControlZ(CtrlSpec(tuple(dtypes), tuple(cvs)))


@bloq_example(generalizer=ignore_split_join)
def _refl_using_prep() -> ReflectionUsingPrepare:
Expand Down
Loading

0 comments on commit dcdb655

Please sign in to comment.