From 0d99a237236fcb52015e29371fd774f6c846e123 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 5 Jun 2024 00:03:00 -0700 Subject: [PATCH] Performance improvements in cirq interop --- qualtran/_infra/registers.py | 5 ++ qualtran/bloqs/data_loading/qrom.py | 6 +- .../multiplexers/unary_iteration_bloq.py | 6 +- qualtran/cirq_interop/_cirq_to_bloq.py | 76 ++++++++++--------- 4 files changed, 51 insertions(+), 42 deletions(-) diff --git a/qualtran/_infra/registers.py b/qualtran/_infra/registers.py index 8d440deef..9c41ca546 100644 --- a/qualtran/_infra/registers.py +++ b/qualtran/_infra/registers.py @@ -16,6 +16,7 @@ import enum import itertools from collections import defaultdict +from functools import cached_property from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union import attrs @@ -99,6 +100,10 @@ def total_bits(self) -> int: This is the product of bitsize and each of the dimensions in `shape`. """ + return self._total_bits + + @cached_property + def _total_bits(self) -> int: return self.bitsize * int(np.prod(self.shape)) def adjoint(self) -> 'Register': diff --git a/qualtran/bloqs/data_loading/qrom.py b/qualtran/bloqs/data_loading/qrom.py index f1c82b9a8..2b99dc372 100644 --- a/qualtran/bloqs/data_loading/qrom.py +++ b/qualtran/bloqs/data_loading/qrom.py @@ -145,7 +145,7 @@ def decompose_zero_selection( if self.num_controls == 0: yield self._load_nth_data(zero_indx, cirq.X, **target_regs) elif self.num_controls == 1: - yield self._load_nth_data(zero_indx, lambda q: CNOT().on(controls[0], q), **target_regs) + yield self._load_nth_data(zero_indx, lambda q: cirq.CNOT(controls[0], q), **target_regs) else: ctrl = np.array(controls)[:, np.newaxis] junk = np.array(context.qubit_manager.qalloc(len(controls) - 2))[:, np.newaxis] @@ -157,7 +157,7 @@ def decompose_zero_selection( ctrl=ctrl, junk=junk, target=and_target ) yield multi_controlled_and - yield self._load_nth_data(zero_indx, lambda q: CNOT().on(and_target, q), **target_regs) + yield self._load_nth_data(zero_indx, lambda q: cirq.CNOT(and_target, q), **target_regs) yield cirq.inverse(multi_controlled_and) context.qubit_manager.qfree(list(junk.flatten()) + [and_target]) @@ -178,7 +178,7 @@ def nth_operation( ) -> Iterator[cirq.OP_TREE]: selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers) target_regs = {reg.name: kwargs[reg.name] for reg in self.target_registers} - yield self._load_nth_data(selection_idx, lambda q: CNOT().on(control, q), **target_regs) + yield self._load_nth_data(selection_idx, lambda q: cirq.CNOT(control, q), **target_regs) def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo: from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info diff --git a/qualtran/bloqs/multiplexers/unary_iteration_bloq.py b/qualtran/bloqs/multiplexers/unary_iteration_bloq.py index 8b9b9505a..f8d549f0c 100644 --- a/qualtran/bloqs/multiplexers/unary_iteration_bloq.py +++ b/qualtran/bloqs/multiplexers/unary_iteration_bloq.py @@ -108,7 +108,7 @@ def _unary_iteration_segtree( yield from _unary_iteration_segtree( ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early ) - ops.append(CNOT().on(control, anc)) + ops.append(cirq.CNOT(control, anc)) yield from _unary_iteration_segtree( ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early ) @@ -130,11 +130,11 @@ def _unary_iteration_zero_control( ops, selection[1:], ancilla, l_iter, r_iter, break_early ) return - ops.append(XGate().on(selection[0])) + ops.append(cirq.X(selection[0])) yield from _unary_iteration_segtree( ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter, break_early ) - ops.append(XGate().on(selection[0])) + ops.append(cirq.X(selection[0])) yield from _unary_iteration_segtree( ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter, break_early ) diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py index 004e1f950..9c2f78680 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq.py +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -16,8 +16,8 @@ import abc import itertools import numbers -from functools import cached_property -from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union +from functools import cache, cached_property +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, TypeVar, Union import cirq import numpy as np @@ -342,49 +342,23 @@ def _gather_input_soqs( return qvars_in -def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq: - from qualtran import Adjoint +@cache +def _cirq_gate_to_bloq_map() -> Dict[cirq.Gate, Bloq]: + # Check specific basic gates instances. from qualtran.bloqs.basic_gates import ( CNOT, CSwap, - CZPowGate, - GlobalPhase, Hadamard, - Rx, - Ry, - Rz, SGate, TGate, Toffoli, TwoBitSwap, XGate, - XPowGate, YGate, - YPowGate, ZGate, - ZPowGate, ) - from qualtran.cirq_interop import CirqGateAsBloq - from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate - - if isinstance(gate, BloqAsCirqGate): - # Perhaps this operation was constructed from `Bloq.on()`. - return gate.bloq - if isinstance(gate, Bloq): - # I.e., `GateWithRegisters`. - return gate - - if isinstance(gate, cirq.ops.raw_types._InverseCompositeGate): - # Inverse of a cirq gate, delegate to Adjoint - return Adjoint(_cirq_gate_to_bloq(gate._original)) - if isinstance(gate, cirq.ControlledGate): - return Controlled( - _cirq_gate_to_bloq(gate.sub_gate), CtrlSpec.from_cirq_cv(gate.control_values) - ) - - # Check specific basic gates instances. - CIRQ_GATE_TO_BLOQ_MAP = { + return { cirq.T: TGate(), cirq.T**-1: TGate().adjoint(), cirq.S: SGate(), @@ -398,11 +372,13 @@ def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq: cirq.SWAP: TwoBitSwap(), cirq.CSWAP: CSwap(1), } - if gate in CIRQ_GATE_TO_BLOQ_MAP: - return CIRQ_GATE_TO_BLOQ_MAP[gate] - # Check specific basic gates types. - CIRQ_TYPE_TO_BLOQ_MAP = { + +@cache +def _cirq_type_to_bloq_map() -> Dict[Type[cirq.Gate], Type[Bloq]]: + from qualtran.bloqs.basic_gates import CZPowGate, Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate + + return { cirq.Rz: Rz, cirq.Rx: Rx, cirq.Ry: Ry, @@ -411,6 +387,34 @@ def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq: cirq.ZPowGate: ZPowGate, cirq.CZPowGate: CZPowGate, } + + +def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq: + from qualtran import Adjoint + from qualtran.bloqs.basic_gates import GlobalPhase + from qualtran.cirq_interop import CirqGateAsBloq + from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate + + if isinstance(gate, BloqAsCirqGate): + # Perhaps this operation was constructed from `Bloq.on()`. + return gate.bloq + if isinstance(gate, Bloq): + # I.e., `GateWithRegisters`. + return gate + + if isinstance(gate, cirq.ops.raw_types._InverseCompositeGate): + # Inverse of a cirq gate, delegate to Adjoint + return Adjoint(_cirq_gate_to_bloq(gate._original)) + + if isinstance(gate, cirq.ControlledGate): + return Controlled( + _cirq_gate_to_bloq(gate.sub_gate), CtrlSpec.from_cirq_cv(gate.control_values) + ) + CIRQ_GATE_TO_BLOQ_MAP = _cirq_gate_to_bloq_map() + if gate in CIRQ_GATE_TO_BLOQ_MAP: + return CIRQ_GATE_TO_BLOQ_MAP[gate] + CIRQ_TYPE_TO_BLOQ_MAP = _cirq_type_to_bloq_map() + # Check specific basic gates types. if isinstance(gate, (cirq.Rx, cirq.Ry, cirq.Rz)): return CIRQ_TYPE_TO_BLOQ_MAP[gate.__class__](angle=gate._rads)