Skip to content

Commit 6c8e772

Browse files
authored
Fix symbolic call graphs for factoring phase estimates (#1497)
* Add some symbolic decomp type errors to build call graphs * Fixed call graph for FindECCPrivateKey and added typechecking for symbolics * fix comparison bitsize error
1 parent e10560c commit 6c8e772

File tree

9 files changed

+120
-62
lines changed

9 files changed

+120
-62
lines changed

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,9 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
988988
def build_composite_bloq(
989989
self, bb: 'BloqBuilder', x: 'Soquet', y: 'Soquet', target: 'Soquet'
990990
) -> Dict[str, 'SoquetT']:
991+
if is_symbolic(self.bitsize):
992+
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")
993+
991994
cvs: Union[list[int], HasLength]
992995
if isinstance(self.bitsize, int):
993996
cvs = [0] * self.bitsize
@@ -1151,6 +1154,8 @@ def wire_symbol(
11511154
def build_composite_bloq(
11521155
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet', target: 'Soquet'
11531156
) -> Dict[str, 'SoquetT']:
1157+
if is_symbolic(self.dtype.bitsize):
1158+
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")
11541159

11551160
if isinstance(self.dtype, QInt):
11561161
a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a)
@@ -1360,6 +1365,9 @@ def build_composite_bloq(
13601365
c: Optional['Soquet'] = None,
13611366
target: Optional['Soquet'] = None,
13621367
) -> Dict[str, 'SoquetT']:
1368+
if is_symbolic(self.dtype.bitsize):
1369+
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")
1370+
13631371
if self.uncompute:
13641372
# Uncompute
13651373
assert c is not None

qualtran/bloqs/arithmetic/controlled_addition.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
bloq_example,
2424
BloqBuilder,
2525
BloqDocSpec,
26+
DecomposeTypeError,
2627
QBit,
2728
QInt,
2829
QUInt,
@@ -37,6 +38,7 @@
3738
from qualtran.bloqs.mcmt.and_bloq import And
3839
from qualtran.resource_counting.generalizers import ignore_split_join
3940
from qualtran.simulation.classical_sim import add_ints
41+
from qualtran.symbolics.types import is_symbolic
4042

4143
if TYPE_CHECKING:
4244
import quimb.tensor as qtn
@@ -134,6 +136,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
134136
def build_composite_bloq(
135137
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet'
136138
) -> Dict[str, 'SoquetT']:
139+
if is_symbolic(self.a_dtype.bitsize, self.b_dtype.bitsize):
140+
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")
141+
137142
a_arr = bb.split(a)
138143
ctrl_q = bb.split(ctrl)[0]
139144
ancilla_arr = []

qualtran/bloqs/factoring/_factoring_shims.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,29 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Optional, Tuple
16+
from typing import Dict, Optional, Tuple
1717

18+
import numpy as np
19+
import sympy
1820
from attrs import frozen
1921

20-
from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature
22+
from qualtran import (
23+
Bloq,
24+
BloqBuilder,
25+
DecomposeTypeError,
26+
QBit,
27+
QUInt,
28+
Register,
29+
Side,
30+
Signature,
31+
Soquet,
32+
SoquetT,
33+
)
34+
from qualtran.bloqs.basic_gates._shims import Measure
35+
from qualtran.bloqs.qft import QFTTextBook
2136
from qualtran.drawing import RarrowTextBox, Text, WireSymbol
22-
from qualtran.symbolics import SymbolicInt
37+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
38+
from qualtran.symbolics.types import SymbolicInt
2339

2440

2541
@frozen
@@ -30,8 +46,21 @@ class MeasureQFT(Bloq):
3046
def signature(self) -> 'Signature':
3147
return Signature([Register('x', QBit(), shape=(self.n,), side=Side.LEFT)])
3248

33-
def decompose_bloq(self) -> 'CompositeBloq':
34-
raise DecomposeTypeError('MeasureQFT is a placeholder, atomic bloq.')
49+
def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']:
50+
if isinstance(self.n, sympy.Expr):
51+
raise DecomposeTypeError("Cannot decompose symbolic `n`.")
52+
53+
x = bb.join(np.array(x), dtype=QUInt(self.n))
54+
x = bb.add(QFTTextBook(self.n), q=x)
55+
x = bb.split(x)
56+
57+
for i in range(self.n):
58+
bb.add(Measure(), q=x[i])
59+
60+
return {}
61+
62+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
63+
return {QFTTextBook(self.n): 1, Measure(): self.n}
3564

3665
def wire_symbol(
3766
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()

qualtran/bloqs/factoring/ecc/ec_add.py

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
)
5050
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
5151
from qualtran.simulation.classical_sim import ClassicalValT
52-
from qualtran.symbolics.types import HasLength, is_symbolic
52+
from qualtran.symbolics.types import HasLength, is_symbolic, SymbolicInt
5353

5454
from .ec_point import ECPoint
5555

@@ -80,8 +80,8 @@ class _ECAddStepOne(Bloq):
8080
Fig 10.
8181
"""
8282

83-
n: int
84-
mod: int
83+
n: 'SymbolicInt'
84+
mod: 'SymbolicInt'
8585

8686
@cached_property
8787
def signature(self) -> 'Signature':
@@ -214,9 +214,9 @@ class _ECAddStepTwo(Bloq):
214214
Fig 10.
215215
"""
216216

217-
n: int
218-
mod: int
219-
window_size: int = 1
217+
n: 'SymbolicInt'
218+
mod: 'SymbolicInt'
219+
window_size: 'SymbolicInt' = 1
220220

221221
@cached_property
222222
def signature(self) -> 'Signature':
@@ -251,7 +251,9 @@ def on_classical_vals(
251251
f1 = 0
252252
else:
253253
lam = QMontgomeryUInt(self.n).montgomery_product(
254-
int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod
254+
int(y),
255+
QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)),
256+
int(self.mod),
255257
)
256258
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit
257259
# which flips f1 when lam and lam_r are equal.
@@ -299,7 +301,7 @@ def build_composite_bloq(
299301
# If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p.
300302
z4_split = bb.split(z4)
301303
lam_split = bb.split(lam)
302-
for i in range(self.n):
304+
for i in range(int(self.n)):
303305
ctrls = [f1, ctrl, z4_split[i]]
304306
ctrls, lam_split[i] = bb.add(
305307
MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i]
@@ -311,7 +313,7 @@ def build_composite_bloq(
311313

312314
# If ctrl = 1 and x = a: lam = lam_r.
313315
lam_r_split = bb.split(lam_r)
314-
for i in range(self.n):
316+
for i in range(int(self.n)):
315317
ctrls = [f1, ctrl, lam_r_split[i]]
316318
ctrls, lam_split[i] = bb.add(
317319
MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i]
@@ -383,9 +385,9 @@ class _ECAddStepThree(Bloq):
383385
Fig 10.
384386
"""
385387

386-
n: int
387-
mod: int
388-
window_size: int = 1
388+
n: 'SymbolicInt'
389+
mod: 'SymbolicInt'
390+
window_size: 'SymbolicInt' = 1
389391

390392
@cached_property
391393
def signature(self) -> 'Signature':
@@ -455,7 +457,7 @@ def build_composite_bloq(
455457
z1 = bb.add(IntState(bitsize=self.n, val=0))
456458
a_split = bb.split(a)
457459
z1_split = bb.split(z1)
458-
for i in range(self.n):
460+
for i in range(int(self.n)):
459461
a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i])
460462
a = bb.join(a_split, QMontgomeryUInt(self.n))
461463
z1 = bb.join(z1_split, QMontgomeryUInt(self.n))
@@ -472,7 +474,7 @@ def build_composite_bloq(
472474
z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1)
473475
a_split = bb.split(a)
474476
z1_split = bb.split(z1)
475-
for i in range(self.n):
477+
for i in range(int(self.n)):
476478
a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i])
477479
a = bb.join(a_split, QMontgomeryUInt(self.n))
478480
z1 = bb.join(z1_split, QMontgomeryUInt(self.n))
@@ -520,9 +522,9 @@ class _ECAddStepFour(Bloq):
520522
Fig 10.
521523
"""
522524

523-
n: int
524-
mod: int
525-
window_size: int = 1
525+
n: 'SymbolicInt'
526+
mod: 'SymbolicInt'
527+
window_size: 'SymbolicInt' = 1
526528

527529
@cached_property
528530
def signature(self) -> 'Signature':
@@ -538,10 +540,10 @@ def on_classical_vals(
538540
self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT'
539541
) -> Dict[str, 'ClassicalValT']:
540542
x = (
541-
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod)
543+
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), int(self.mod))
542544
) % self.mod
543545
if lam > 0:
544-
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod)
546+
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), int(self.mod))
545547
return {'x': x, 'y': y, 'lam': lam}
546548

547549
def build_composite_bloq(
@@ -554,7 +556,7 @@ def build_composite_bloq(
554556
z4 = bb.add(IntState(bitsize=self.n, val=0))
555557
lam_split = bb.split(lam)
556558
z4_split = bb.split(z4)
557-
for i in range(self.n):
559+
for i in range(int(self.n)):
558560
lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i])
559561
lam = bb.join(lam_split, QMontgomeryUInt(self.n))
560562
z4 = bb.join(z4_split, QMontgomeryUInt(self.n))
@@ -584,7 +586,7 @@ def build_composite_bloq(
584586
)
585587
lam_split = bb.split(lam)
586588
z4_split = bb.split(z4)
587-
for i in range(self.n):
589+
for i in range(int(self.n)):
588590
lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i])
589591
lam = bb.join(lam_split, QMontgomeryUInt(self.n))
590592
z4 = bb.join(z4_split, QMontgomeryUInt(self.n))
@@ -602,7 +604,7 @@ def build_composite_bloq(
602604
# y = y_r + b % p.
603605
z3_split = bb.split(z3)
604606
y_split = bb.split(y)
605-
for i in range(self.n):
607+
for i in range(int(self.n)):
606608
z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i])
607609
z3 = bb.join(z3_split, QMontgomeryUInt(self.n))
608610
y = bb.join(y_split, QMontgomeryUInt(self.n))
@@ -659,9 +661,9 @@ class _ECAddStepFive(Bloq):
659661
Fig 10.
660662
"""
661663

662-
n: int
663-
mod: int
664-
window_size: int = 1
664+
n: 'SymbolicInt'
665+
mod: 'SymbolicInt'
666+
window_size: 'SymbolicInt' = 1
665667

666668
@cached_property
667669
def signature(self) -> 'Signature':
@@ -720,7 +722,7 @@ def build_composite_bloq(
720722
# If ctrl: lam = 0.
721723
z4_split = bb.split(z4)
722724
lam_split = bb.split(lam)
723-
for i in range(self.n):
725+
for i in range(int(self.n)):
724726
ctrls = [ctrl, z4_split[i]]
725727
ctrls, lam_split[i] = bb.add(
726728
MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i]
@@ -801,8 +803,8 @@ class _ECAddStepSix(Bloq):
801803
Fig 10.
802804
"""
803805

804-
n: int
805-
mod: int
806+
n: 'SymbolicInt'
807+
mod: 'SymbolicInt'
806808

807809
@cached_property
808810
def signature(self) -> 'Signature':
@@ -866,7 +868,7 @@ def build_composite_bloq(
866868
# Set (x, y) to (a, b) if f4 is set.
867869
a_split = bb.split(a)
868870
x_split = bb.split(x)
869-
for i in range(self.n):
871+
for i in range(int(self.n)):
870872
toff_ctrl = [f4, a_split[i]]
871873
toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i])
872874
f4 = toff_ctrl[0]
@@ -875,7 +877,7 @@ def build_composite_bloq(
875877
x = bb.join(x_split, QMontgomeryUInt(self.n))
876878
b_split = bb.split(b)
877879
y_split = bb.split(y)
878-
for i in range(self.n):
880+
for i in range(int(self.n)):
879881
toff_ctrl = [f4, b_split[i]]
880882
toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i])
881883
f4 = toff_ctrl[0]
@@ -888,11 +890,11 @@ def build_composite_bloq(
888890
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
889891
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
890892
ab_split = bb.split(ab)
891-
a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n))
892-
b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n))
893+
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
894+
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
893895
xy_split = bb.split(xy)
894-
x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n))
895-
y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n))
896+
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
897+
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
896898

897899
# Unset f3 if (a, b) = (0, 0).
898900
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
@@ -1000,9 +1002,9 @@ class ECAdd(Bloq):
10001002
Litinski. 2023. Fig 5.
10011003
"""
10021004

1003-
n: int
1004-
mod: int
1005-
window_size: int = 1
1005+
n: 'SymbolicInt'
1006+
mod: 'SymbolicInt'
1007+
window_size: 'SymbolicInt' = 1
10061008

10071009
@cached_property
10081010
def signature(self) -> 'Signature':
@@ -1070,29 +1072,29 @@ def build_composite_bloq(
10701072

10711073
def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
10721074
curve_a = (
1073-
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod)
1075+
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, int(self.mod))
10741076
* 2
1075-
* QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod)
1076-
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2)
1077+
* QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod))
1078+
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)) ** 2)
10771079
) % self.mod
10781080
p1 = ECPoint(
1079-
QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod),
1080-
QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod),
1081+
QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)),
1082+
QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod)),
10811083
mod=self.mod,
10821084
curve_a=curve_a,
10831085
)
10841086
p2 = ECPoint(
1085-
QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod),
1086-
QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod),
1087+
QMontgomeryUInt(self.n).montgomery_to_uint(x, int(self.mod)),
1088+
QMontgomeryUInt(self.n).montgomery_to_uint(y, int(self.mod)),
10871089
mod=self.mod,
10881090
curve_a=curve_a,
10891091
)
10901092
result = p1 + p2
10911093
return {
10921094
'a': a,
10931095
'b': b,
1094-
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod),
1095-
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod),
1096+
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, int(self.mod)),
1097+
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, int(self.mod)),
10961098
'lam_r': lam_r,
10971099
}
10981100

0 commit comments

Comments
 (0)