Skip to content

Commit

Permalink
Fix symbolic call graphs for factoring phase estimates (#1497)
Browse files Browse the repository at this point in the history
* Add some symbolic decomp type errors to build call graphs

* Fixed call graph for FindECCPrivateKey and added typechecking for symbolics

* fix comparison bitsize error
  • Loading branch information
fpapa250 authored Nov 18, 2024
1 parent e10560c commit 6c8e772
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 62 deletions.
8 changes: 8 additions & 0 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,9 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
def build_composite_bloq(
self, bb: 'BloqBuilder', x: 'Soquet', y: 'Soquet', target: 'Soquet'
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

cvs: Union[list[int], HasLength]
if isinstance(self.bitsize, int):
cvs = [0] * self.bitsize
Expand Down Expand Up @@ -1151,6 +1154,8 @@ def wire_symbol(
def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet', target: 'Soquet'
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

if isinstance(self.dtype, QInt):
a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a)
Expand Down Expand Up @@ -1360,6 +1365,9 @@ def build_composite_bloq(
c: Optional['Soquet'] = None,
target: Optional['Soquet'] = None,
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

if self.uncompute:
# Uncompute
assert c is not None
Expand Down
5 changes: 5 additions & 0 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
DecomposeTypeError,
QBit,
QInt,
QUInt,
Expand All @@ -37,6 +38,7 @@
from qualtran.bloqs.mcmt.and_bloq import And
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.simulation.classical_sim import add_ints
from qualtran.symbolics.types import is_symbolic

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -134,6 +136,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet'
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.a_dtype.bitsize, self.b_dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

a_arr = bb.split(a)
ctrl_q = bb.split(ctrl)[0]
ancilla_arr = []
Expand Down
39 changes: 34 additions & 5 deletions qualtran/bloqs/factoring/_factoring_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,29 @@
# limitations under the License.

from functools import cached_property
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import numpy as np
import sympy
from attrs import frozen

from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature
from qualtran import (
Bloq,
BloqBuilder,
DecomposeTypeError,
QBit,
QUInt,
Register,
Side,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates._shims import Measure
from qualtran.bloqs.qft import QFTTextBook
from qualtran.drawing import RarrowTextBox, Text, WireSymbol
from qualtran.symbolics import SymbolicInt
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.symbolics.types import SymbolicInt


@frozen
Expand All @@ -30,8 +46,21 @@ class MeasureQFT(Bloq):
def signature(self) -> 'Signature':
return Signature([Register('x', QBit(), shape=(self.n,), side=Side.LEFT)])

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError('MeasureQFT is a placeholder, atomic bloq.')
def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']:
if isinstance(self.n, sympy.Expr):
raise DecomposeTypeError("Cannot decompose symbolic `n`.")

x = bb.join(np.array(x), dtype=QUInt(self.n))
x = bb.add(QFTTextBook(self.n), q=x)
x = bb.split(x)

for i in range(self.n):
bb.add(Measure(), q=x[i])

return {}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {QFTTextBook(self.n): 1, Measure(): self.n}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
Expand Down
94 changes: 48 additions & 46 deletions qualtran/bloqs/factoring/ecc/ec_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics.types import HasLength, is_symbolic
from qualtran.symbolics.types import HasLength, is_symbolic, SymbolicInt

from .ec_point import ECPoint

Expand Down Expand Up @@ -80,8 +80,8 @@ class _ECAddStepOne(Bloq):
Fig 10.
"""

n: int
mod: int
n: 'SymbolicInt'
mod: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -214,9 +214,9 @@ class _ECAddStepTwo(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -251,7 +251,9 @@ def on_classical_vals(
f1 = 0
else:
lam = QMontgomeryUInt(self.n).montgomery_product(
int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod
int(y),
QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)),
int(self.mod),
)
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit
# which flips f1 when lam and lam_r are equal.
Expand Down Expand Up @@ -299,7 +301,7 @@ def build_composite_bloq(
# If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p.
z4_split = bb.split(z4)
lam_split = bb.split(lam)
for i in range(self.n):
for i in range(int(self.n)):
ctrls = [f1, ctrl, z4_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i]
Expand All @@ -311,7 +313,7 @@ def build_composite_bloq(

# If ctrl = 1 and x = a: lam = lam_r.
lam_r_split = bb.split(lam_r)
for i in range(self.n):
for i in range(int(self.n)):
ctrls = [f1, ctrl, lam_r_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i]
Expand Down Expand Up @@ -383,9 +385,9 @@ class _ECAddStepThree(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -455,7 +457,7 @@ def build_composite_bloq(
z1 = bb.add(IntState(bitsize=self.n, val=0))
a_split = bb.split(a)
z1_split = bb.split(z1)
for i in range(self.n):
for i in range(int(self.n)):
a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i])
a = bb.join(a_split, QMontgomeryUInt(self.n))
z1 = bb.join(z1_split, QMontgomeryUInt(self.n))
Expand All @@ -472,7 +474,7 @@ def build_composite_bloq(
z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1)
a_split = bb.split(a)
z1_split = bb.split(z1)
for i in range(self.n):
for i in range(int(self.n)):
a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i])
a = bb.join(a_split, QMontgomeryUInt(self.n))
z1 = bb.join(z1_split, QMontgomeryUInt(self.n))
Expand Down Expand Up @@ -520,9 +522,9 @@ class _ECAddStepFour(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand All @@ -538,10 +540,10 @@ def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
x = (
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod)
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), int(self.mod))
) % self.mod
if lam > 0:
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod)
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), int(self.mod))
return {'x': x, 'y': y, 'lam': lam}

def build_composite_bloq(
Expand All @@ -554,7 +556,7 @@ def build_composite_bloq(
z4 = bb.add(IntState(bitsize=self.n, val=0))
lam_split = bb.split(lam)
z4_split = bb.split(z4)
for i in range(self.n):
for i in range(int(self.n)):
lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i])
lam = bb.join(lam_split, QMontgomeryUInt(self.n))
z4 = bb.join(z4_split, QMontgomeryUInt(self.n))
Expand Down Expand Up @@ -584,7 +586,7 @@ def build_composite_bloq(
)
lam_split = bb.split(lam)
z4_split = bb.split(z4)
for i in range(self.n):
for i in range(int(self.n)):
lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i])
lam = bb.join(lam_split, QMontgomeryUInt(self.n))
z4 = bb.join(z4_split, QMontgomeryUInt(self.n))
Expand All @@ -602,7 +604,7 @@ def build_composite_bloq(
# y = y_r + b % p.
z3_split = bb.split(z3)
y_split = bb.split(y)
for i in range(self.n):
for i in range(int(self.n)):
z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i])
z3 = bb.join(z3_split, QMontgomeryUInt(self.n))
y = bb.join(y_split, QMontgomeryUInt(self.n))
Expand Down Expand Up @@ -659,9 +661,9 @@ class _ECAddStepFive(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -720,7 +722,7 @@ def build_composite_bloq(
# If ctrl: lam = 0.
z4_split = bb.split(z4)
lam_split = bb.split(lam)
for i in range(self.n):
for i in range(int(self.n)):
ctrls = [ctrl, z4_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i]
Expand Down Expand Up @@ -801,8 +803,8 @@ class _ECAddStepSix(Bloq):
Fig 10.
"""

n: int
mod: int
n: 'SymbolicInt'
mod: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -866,7 +868,7 @@ def build_composite_bloq(
# Set (x, y) to (a, b) if f4 is set.
a_split = bb.split(a)
x_split = bb.split(x)
for i in range(self.n):
for i in range(int(self.n)):
toff_ctrl = [f4, a_split[i]]
toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i])
f4 = toff_ctrl[0]
Expand All @@ -875,7 +877,7 @@ def build_composite_bloq(
x = bb.join(x_split, QMontgomeryUInt(self.n))
b_split = bb.split(b)
y_split = bb.split(y)
for i in range(self.n):
for i in range(int(self.n)):
toff_ctrl = [f4, b_split[i]]
toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i])
f4 = toff_ctrl[0]
Expand All @@ -888,11 +890,11 @@ def build_composite_bloq(
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
ab_split = bb.split(ab)
a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n))
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
xy_split = bb.split(xy)
x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n))
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))

# Unset f3 if (a, b) = (0, 0).
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
Expand Down Expand Up @@ -1000,9 +1002,9 @@ class ECAdd(Bloq):
Litinski. 2023. Fig 5.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -1070,29 +1072,29 @@ def build_composite_bloq(

def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
curve_a = (
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod)
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, int(self.mod))
* 2
* QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod)
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2)
* QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod))
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)) ** 2)
) % self.mod
p1 = ECPoint(
QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)),
QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod)),
mod=self.mod,
curve_a=curve_a,
)
p2 = ECPoint(
QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(x, int(self.mod)),
QMontgomeryUInt(self.n).montgomery_to_uint(y, int(self.mod)),
mod=self.mod,
curve_a=curve_a,
)
result = p1 + p2
return {
'a': a,
'b': b,
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod),
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod),
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, int(self.mod)),
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, int(self.mod)),
'lam_r': lam_r,
}

Expand Down
Loading

0 comments on commit 6c8e772

Please sign in to comment.