From 6c8e772a31dc4e86865b449a8ceb3fca53332dae Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Mon, 18 Nov 2024 10:19:22 -0800 Subject: [PATCH] Fix symbolic call graphs for factoring phase estimates (#1497) * Add some symbolic decomp type errors to build call graphs * Fixed call graph for FindECCPrivateKey and added typechecking for symbolics * fix comparison bitsize error --- qualtran/bloqs/arithmetic/comparison.py | 8 ++ .../bloqs/arithmetic/controlled_addition.py | 5 + qualtran/bloqs/factoring/_factoring_shims.py | 39 +++++++- qualtran/bloqs/factoring/ecc/ec_add.py | 94 ++++++++++--------- qualtran/bloqs/factoring/ecc/ec_add_r.py | 10 +- .../factoring/ecc/ec_phase_estimate_r.py | 7 +- .../factoring/ecc/find_ecc_private_key.py | 6 +- qualtran/bloqs/mod_arithmetic/mod_division.py | 9 ++ .../mod_arithmetic/mod_multiplication.py | 4 + 9 files changed, 120 insertions(+), 62 deletions(-) diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index fff30d5fc..a47d95cff 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -988,6 +988,9 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - def build_composite_bloq( self, bb: 'BloqBuilder', x: 'Soquet', y: 'Soquet', target: 'Soquet' ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + cvs: Union[list[int], HasLength] if isinstance(self.bitsize, int): cvs = [0] * self.bitsize @@ -1151,6 +1154,8 @@ def wire_symbol( def build_composite_bloq( self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet', target: 'Soquet' ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.dtype.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") if isinstance(self.dtype, QInt): a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a) @@ -1360,6 +1365,9 @@ def build_composite_bloq( c: Optional['Soquet'] = None, target: Optional['Soquet'] = None, ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.dtype.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + if self.uncompute: # Uncompute assert c is not None diff --git a/qualtran/bloqs/arithmetic/controlled_addition.py b/qualtran/bloqs/arithmetic/controlled_addition.py index cfe4bcc5d..dc6751d23 100644 --- a/qualtran/bloqs/arithmetic/controlled_addition.py +++ b/qualtran/bloqs/arithmetic/controlled_addition.py @@ -23,6 +23,7 @@ bloq_example, BloqBuilder, BloqDocSpec, + DecomposeTypeError, QBit, QInt, QUInt, @@ -37,6 +38,7 @@ from qualtran.bloqs.mcmt.and_bloq import And from qualtran.resource_counting.generalizers import ignore_split_join from qualtran.simulation.classical_sim import add_ints +from qualtran.symbolics.types import is_symbolic if TYPE_CHECKING: import quimb.tensor as qtn @@ -134,6 +136,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': def build_composite_bloq( self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet' ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.a_dtype.bitsize, self.b_dtype.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + a_arr = bb.split(a) ctrl_q = bb.split(ctrl)[0] ancilla_arr = [] diff --git a/qualtran/bloqs/factoring/_factoring_shims.py b/qualtran/bloqs/factoring/_factoring_shims.py index 896e103b8..1bc32a5c6 100644 --- a/qualtran/bloqs/factoring/_factoring_shims.py +++ b/qualtran/bloqs/factoring/_factoring_shims.py @@ -13,13 +13,29 @@ # limitations under the License. from functools import cached_property -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple +import numpy as np +import sympy from attrs import frozen -from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature +from qualtran import ( + Bloq, + BloqBuilder, + DecomposeTypeError, + QBit, + QUInt, + Register, + Side, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.basic_gates._shims import Measure +from qualtran.bloqs.qft import QFTTextBook from qualtran.drawing import RarrowTextBox, Text, WireSymbol -from qualtran.symbolics import SymbolicInt +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.symbolics.types import SymbolicInt @frozen @@ -30,8 +46,21 @@ class MeasureQFT(Bloq): def signature(self) -> 'Signature': return Signature([Register('x', QBit(), shape=(self.n,), side=Side.LEFT)]) - def decompose_bloq(self) -> 'CompositeBloq': - raise DecomposeTypeError('MeasureQFT is a placeholder, atomic bloq.') + def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']: + if isinstance(self.n, sympy.Expr): + raise DecomposeTypeError("Cannot decompose symbolic `n`.") + + x = bb.join(np.array(x), dtype=QUInt(self.n)) + x = bb.add(QFTTextBook(self.n), q=x) + x = bb.split(x) + + for i in range(self.n): + bb.add(Measure(), q=x[i]) + + return {} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return {QFTTextBook(self.n): 1, Measure(): self.n} def wire_symbol( self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index 572e47673..f764c70cb 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -49,7 +49,7 @@ ) from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT -from qualtran.symbolics.types import HasLength, is_symbolic +from qualtran.symbolics.types import HasLength, is_symbolic, SymbolicInt from .ec_point import ECPoint @@ -80,8 +80,8 @@ class _ECAddStepOne(Bloq): Fig 10. """ - n: int - mod: int + n: 'SymbolicInt' + mod: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': @@ -214,9 +214,9 @@ class _ECAddStepTwo(Bloq): Fig 10. """ - n: int - mod: int - window_size: int = 1 + n: 'SymbolicInt' + mod: 'SymbolicInt' + window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': @@ -251,7 +251,9 @@ def on_classical_vals( f1 = 0 else: lam = QMontgomeryUInt(self.n).montgomery_product( - int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod + int(y), + QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)), + int(self.mod), ) # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit # which flips f1 when lam and lam_r are equal. @@ -299,7 +301,7 @@ def build_composite_bloq( # If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p. z4_split = bb.split(z4) lam_split = bb.split(lam) - for i in range(self.n): + for i in range(int(self.n)): ctrls = [f1, ctrl, z4_split[i]] ctrls, lam_split[i] = bb.add( MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i] @@ -311,7 +313,7 @@ def build_composite_bloq( # If ctrl = 1 and x = a: lam = lam_r. lam_r_split = bb.split(lam_r) - for i in range(self.n): + for i in range(int(self.n)): ctrls = [f1, ctrl, lam_r_split[i]] ctrls, lam_split[i] = bb.add( MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i] @@ -383,9 +385,9 @@ class _ECAddStepThree(Bloq): Fig 10. """ - n: int - mod: int - window_size: int = 1 + n: 'SymbolicInt' + mod: 'SymbolicInt' + window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': @@ -455,7 +457,7 @@ def build_composite_bloq( z1 = bb.add(IntState(bitsize=self.n, val=0)) a_split = bb.split(a) z1_split = bb.split(z1) - for i in range(self.n): + for i in range(int(self.n)): a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) a = bb.join(a_split, QMontgomeryUInt(self.n)) z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) @@ -472,7 +474,7 @@ def build_composite_bloq( z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1) a_split = bb.split(a) z1_split = bb.split(z1) - for i in range(self.n): + for i in range(int(self.n)): a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) a = bb.join(a_split, QMontgomeryUInt(self.n)) z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) @@ -520,9 +522,9 @@ class _ECAddStepFour(Bloq): Fig 10. """ - n: int - mod: int - window_size: int = 1 + n: 'SymbolicInt' + mod: 'SymbolicInt' + window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': @@ -538,10 +540,10 @@ def on_classical_vals( self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: x = ( - x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod) + x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), int(self.mod)) ) % self.mod if lam > 0: - y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod) + y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), int(self.mod)) return {'x': x, 'y': y, 'lam': lam} def build_composite_bloq( @@ -554,7 +556,7 @@ def build_composite_bloq( z4 = bb.add(IntState(bitsize=self.n, val=0)) lam_split = bb.split(lam) z4_split = bb.split(z4) - for i in range(self.n): + for i in range(int(self.n)): lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) lam = bb.join(lam_split, QMontgomeryUInt(self.n)) z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) @@ -584,7 +586,7 @@ def build_composite_bloq( ) lam_split = bb.split(lam) z4_split = bb.split(z4) - for i in range(self.n): + for i in range(int(self.n)): lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) lam = bb.join(lam_split, QMontgomeryUInt(self.n)) z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) @@ -602,7 +604,7 @@ def build_composite_bloq( # y = y_r + b % p. z3_split = bb.split(z3) y_split = bb.split(y) - for i in range(self.n): + for i in range(int(self.n)): z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i]) z3 = bb.join(z3_split, QMontgomeryUInt(self.n)) y = bb.join(y_split, QMontgomeryUInt(self.n)) @@ -659,9 +661,9 @@ class _ECAddStepFive(Bloq): Fig 10. """ - n: int - mod: int - window_size: int = 1 + n: 'SymbolicInt' + mod: 'SymbolicInt' + window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': @@ -720,7 +722,7 @@ def build_composite_bloq( # If ctrl: lam = 0. z4_split = bb.split(z4) lam_split = bb.split(lam) - for i in range(self.n): + for i in range(int(self.n)): ctrls = [ctrl, z4_split[i]] ctrls, lam_split[i] = bb.add( MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i] @@ -801,8 +803,8 @@ class _ECAddStepSix(Bloq): Fig 10. """ - n: int - mod: int + n: 'SymbolicInt' + mod: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': @@ -866,7 +868,7 @@ def build_composite_bloq( # Set (x, y) to (a, b) if f4 is set. a_split = bb.split(a) x_split = bb.split(x) - for i in range(self.n): + for i in range(int(self.n)): toff_ctrl = [f4, a_split[i]] toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i]) f4 = toff_ctrl[0] @@ -875,7 +877,7 @@ def build_composite_bloq( x = bb.join(x_split, QMontgomeryUInt(self.n)) b_split = bb.split(b) y_split = bb.split(y) - for i in range(self.n): + for i in range(int(self.n)): toff_ctrl = [f4, b_split[i]] toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i]) f4 = toff_ctrl[0] @@ -888,11 +890,11 @@ def build_composite_bloq( xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) ab_split = bb.split(ab) - a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n)) - b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n)) + a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) xy_split = bb.split(xy) - x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n)) - y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n)) + x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) # Unset f3 if (a, b) = (0, 0). ab_arr = np.concatenate([bb.split(a), bb.split(b)]) @@ -1000,9 +1002,9 @@ class ECAdd(Bloq): Litinski. 2023. Fig 5. """ - n: int - mod: int - window_size: int = 1 + n: 'SymbolicInt' + mod: 'SymbolicInt' + window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': @@ -1070,20 +1072,20 @@ def build_composite_bloq( def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: curve_a = ( - QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod) + QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, int(self.mod)) * 2 - * QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod) - - (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2) + * QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod)) + - (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)) ** 2) ) % self.mod p1 = ECPoint( - QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod), - QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)), + QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod)), mod=self.mod, curve_a=curve_a, ) p2 = ECPoint( - QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod), - QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(x, int(self.mod)), + QMontgomeryUInt(self.n).montgomery_to_uint(y, int(self.mod)), mod=self.mod, curve_a=curve_a, ) @@ -1091,8 +1093,8 @@ def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT return { 'a': a, 'b': b, - 'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod), - 'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod), + 'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, int(self.mod)), + 'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, int(self.mod)), 'lam_r': lam_r, } diff --git a/qualtran/bloqs/factoring/ecc/ec_add_r.py b/qualtran/bloqs/factoring/ecc/ec_add_r.py index 9ced3a9f1..6bf020659 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_r.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_r.py @@ -36,7 +36,7 @@ from qualtran.drawing import Circle, Text, TextBox, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT -from qualtran.symbolics import is_symbolic, Shaped +from qualtran.symbolics import is_symbolic, Shaped, SymbolicInt from .ec_add import ECAdd from .ec_point import ECPoint @@ -75,7 +75,7 @@ class ECAddR(Bloq): """ - n: int + n: 'SymbolicInt' R: ECPoint @cached_property @@ -144,10 +144,10 @@ class ECWindowAddR(Bloq): Litinski. 2013. Section 1, eq. (3) and (4). """ - n: int + n: 'SymbolicInt' R: ECPoint - add_window_size: int - mul_window_size: int = 1 + add_window_size: 'SymbolicInt' + mul_window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': diff --git a/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py b/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py index d56400d20..eb991c19a 100644 --- a/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py +++ b/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py @@ -34,6 +34,7 @@ ) from qualtran.bloqs.basic_gates import PlusState from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.symbolics.types import SymbolicInt from .._factoring_shims import MeasureQFT from .ec_add_r import ECAddR, ECWindowAddR @@ -58,10 +59,10 @@ class ECPhaseEstimateR(Bloq): mul_window_size: The number of bits in the modular multiplication window. """ - n: int + n: 'SymbolicInt' point: ECPoint - add_window_size: int = 1 - mul_window_size: int = 1 + add_window_size: 'SymbolicInt' = 1 + mul_window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': diff --git a/qualtran/bloqs/factoring/ecc/find_ecc_private_key.py b/qualtran/bloqs/factoring/ecc/find_ecc_private_key.py index 05785939f..f6129af80 100644 --- a/qualtran/bloqs/factoring/ecc/find_ecc_private_key.py +++ b/qualtran/bloqs/factoring/ecc/find_ecc_private_key.py @@ -75,11 +75,11 @@ class FindECCPrivateKey(Bloq): Litinski. 2023. Figure 4 (a). """ - n: int + n: 'SymbolicInt' base_point: ECPoint public_key: ECPoint - add_window_size: int = 1 - mul_window_size: int = 1 + add_window_size: 'SymbolicInt' = 1 + mul_window_size: 'SymbolicInt' = 1 @cached_property def signature(self) -> 'Signature': diff --git a/qualtran/bloqs/mod_arithmetic/mod_division.py b/qualtran/bloqs/mod_arithmetic/mod_division.py index 06f643525..d3472103b 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division.py @@ -136,6 +136,9 @@ def on_classical_vals( def build_composite_bloq( self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + u_arr = bb.split(u) v_arr = bb.split(v) @@ -542,6 +545,9 @@ def build_composite_bloq( f: Soquet, terminal_condition: Soquet, ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + f = bb.add(XGate(), q=f) u = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=u) s = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=s) @@ -664,6 +670,9 @@ def signature(self) -> 'Signature': def build_composite_bloq( self, bb: 'BloqBuilder', x: Soquet, junk: Optional[Soquet] = None ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + u = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) r = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) s = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index e02e60712..0e8f3ca0b 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -29,6 +29,7 @@ BloqBuilder, BloqDocSpec, DecomposeNotImplementedError, + DecomposeTypeError, QBit, QInt, QMontgomeryUInt, @@ -91,6 +92,9 @@ def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']: return {'x': x} def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']: + if is_symbolic(self.dtype.bitsize): + raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') + # Allocate ancilla bits for sign and double. lower_bit = bb.allocate(n=1) sign = bb.allocate(n=1)