Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements in cirq interop #1054

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import enum
import itertools
from collections import defaultdict
from functools import cached_property
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union

import attrs
Expand Down Expand Up @@ -99,6 +100,10 @@ def total_bits(self) -> int:

This is the product of bitsize and each of the dimensions in `shape`.
"""
return self._total_bits

@cached_property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just use cached method on the original method?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

functools doesn't seem to have a cached_method and I didn't want to add another dependency. using just cache would require computing hash of the class itself which would again be slow.

def _total_bits(self) -> int:
return self.bitsize * int(np.prod(self.shape))

def adjoint(self) -> 'Register':
Expand Down
6 changes: 3 additions & 3 deletions qualtran/bloqs/data_loading/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def decompose_zero_selection(
if self.num_controls == 0:
yield self._load_nth_data(zero_indx, cirq.X, **target_regs)
elif self.num_controls == 1:
yield self._load_nth_data(zero_indx, lambda q: CNOT().on(controls[0], q), **target_regs)
yield self._load_nth_data(zero_indx, lambda q: cirq.CNOT(controls[0], q), **target_regs)
else:
ctrl = np.array(controls)[:, np.newaxis]
junk = np.array(context.qubit_manager.qalloc(len(controls) - 2))[:, np.newaxis]
Expand All @@ -157,7 +157,7 @@ def decompose_zero_selection(
ctrl=ctrl, junk=junk, target=and_target
)
yield multi_controlled_and
yield self._load_nth_data(zero_indx, lambda q: CNOT().on(and_target, q), **target_regs)
yield self._load_nth_data(zero_indx, lambda q: cirq.CNOT(and_target, q), **target_regs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use qualtran CNOT so that it can immediately unwrap from CirqGateAsBloq without doing any logic?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using qualtran CNOT is slower because CNOT in Qualtran is not a GateWithRegisters. So the cirq construction wraps it in a BloqAsCirqGate; which ends up being slow. See the first profiling picture in my linked comment.

yield cirq.inverse(multi_controlled_and)
context.qubit_manager.qfree(list(junk.flatten()) + [and_target])

Expand All @@ -178,7 +178,7 @@ def nth_operation(
) -> Iterator[cirq.OP_TREE]:
selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers)
target_regs = {reg.name: kwargs[reg.name] for reg in self.target_registers}
yield self._load_nth_data(selection_idx, lambda q: CNOT().on(control, q), **target_regs)
yield self._load_nth_data(selection_idx, lambda q: cirq.CNOT(control, q), **target_regs)

def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
Expand Down
6 changes: 3 additions & 3 deletions qualtran/bloqs/multiplexers/unary_iteration_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _unary_iteration_segtree(
yield from _unary_iteration_segtree(
ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early
)
ops.append(CNOT().on(control, anc))
ops.append(cirq.CNOT(control, anc))
yield from _unary_iteration_segtree(
ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early
)
Expand All @@ -130,11 +130,11 @@ def _unary_iteration_zero_control(
ops, selection[1:], ancilla, l_iter, r_iter, break_early
)
return
ops.append(XGate().on(selection[0]))
ops.append(cirq.X(selection[0]))
yield from _unary_iteration_segtree(
ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter, break_early
)
ops.append(XGate().on(selection[0]))
ops.append(cirq.X(selection[0]))
yield from _unary_iteration_segtree(
ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter, break_early
)
Expand Down
76 changes: 40 additions & 36 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import abc
import itertools
import numbers
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union
from functools import cache, cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, TypeVar, Union

import cirq
import numpy as np
Expand Down Expand Up @@ -342,49 +342,23 @@ def _gather_input_soqs(
return qvars_in


def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq:
from qualtran import Adjoint
@cache
def _cirq_gate_to_bloq_map() -> Dict[cirq.Gate, Bloq]:
# Check specific basic gates instances.
from qualtran.bloqs.basic_gates import (
CNOT,
CSwap,
CZPowGate,
GlobalPhase,
Hadamard,
Rx,
Ry,
Rz,
SGate,
TGate,
Toffoli,
TwoBitSwap,
XGate,
XPowGate,
YGate,
YPowGate,
ZGate,
ZPowGate,
)
from qualtran.cirq_interop import CirqGateAsBloq
from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate

if isinstance(gate, BloqAsCirqGate):
# Perhaps this operation was constructed from `Bloq.on()`.
return gate.bloq
if isinstance(gate, Bloq):
# I.e., `GateWithRegisters`.
return gate

if isinstance(gate, cirq.ops.raw_types._InverseCompositeGate):
# Inverse of a cirq gate, delegate to Adjoint
return Adjoint(_cirq_gate_to_bloq(gate._original))

if isinstance(gate, cirq.ControlledGate):
return Controlled(
_cirq_gate_to_bloq(gate.sub_gate), CtrlSpec.from_cirq_cv(gate.control_values)
)

# Check specific basic gates instances.
CIRQ_GATE_TO_BLOQ_MAP = {
return {
cirq.T: TGate(),
cirq.T**-1: TGate().adjoint(),
cirq.S: SGate(),
Expand All @@ -398,11 +372,13 @@ def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq:
cirq.SWAP: TwoBitSwap(),
cirq.CSWAP: CSwap(1),
}
if gate in CIRQ_GATE_TO_BLOQ_MAP:
return CIRQ_GATE_TO_BLOQ_MAP[gate]

# Check specific basic gates types.
CIRQ_TYPE_TO_BLOQ_MAP = {

@cache
def _cirq_type_to_bloq_map() -> Dict[Type[cirq.Gate], Type[Bloq]]:
from qualtran.bloqs.basic_gates import CZPowGate, Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate

return {
cirq.Rz: Rz,
cirq.Rx: Rx,
cirq.Ry: Ry,
Expand All @@ -411,6 +387,34 @@ def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq:
cirq.ZPowGate: ZPowGate,
cirq.CZPowGate: CZPowGate,
}


def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq:
from qualtran import Adjoint
from qualtran.bloqs.basic_gates import GlobalPhase
from qualtran.cirq_interop import CirqGateAsBloq
from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate

if isinstance(gate, BloqAsCirqGate):
# Perhaps this operation was constructed from `Bloq.on()`.
return gate.bloq
if isinstance(gate, Bloq):
# I.e., `GateWithRegisters`.
return gate

if isinstance(gate, cirq.ops.raw_types._InverseCompositeGate):
# Inverse of a cirq gate, delegate to Adjoint
return Adjoint(_cirq_gate_to_bloq(gate._original))

if isinstance(gate, cirq.ControlledGate):
return Controlled(
_cirq_gate_to_bloq(gate.sub_gate), CtrlSpec.from_cirq_cv(gate.control_values)
)
CIRQ_GATE_TO_BLOQ_MAP = _cirq_gate_to_bloq_map()
if gate in CIRQ_GATE_TO_BLOQ_MAP:
return CIRQ_GATE_TO_BLOQ_MAP[gate]
CIRQ_TYPE_TO_BLOQ_MAP = _cirq_type_to_bloq_map()
# Check specific basic gates types.
if isinstance(gate, (cirq.Rx, cirq.Ry, cirq.Rz)):
return CIRQ_TYPE_TO_BLOQ_MAP[gate.__class__](angle=gate._rads)

Expand Down
Loading