From d19501a6ceb5160be7f36d9c6d88faf88872e354 Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Thu, 16 Nov 2023 17:44:37 +0100 Subject: [PATCH] qint3, fix subscript unrolling, subset example --- examples/subset_ex.py | 13 +++++-------- qlasskit/__init__.py | 3 ++- qlasskit/ast2ast.py | 9 +++++++-- qlasskit/qlassfun.py | 24 +++++++++++++++++++++++- qlasskit/types/__init__.py | 6 +++--- qlasskit/types/qint.py | 4 +++- test/test_algo_grover.py | 5 +---- test/test_qlassf_list.py | 9 +++++++++ 8 files changed, 53 insertions(+), 20 deletions(-) diff --git a/examples/subset_ex.py b/examples/subset_ex.py index cc8c0cd5..57fb13b1 100644 --- a/examples/subset_ex.py +++ b/examples/subset_ex.py @@ -4,7 +4,7 @@ from qiskit import Aer, QuantumCircuit, transpile from qiskit.visualization import plot_histogram -from qlasskit import Qint2, qlassf +from qlasskit import Qint2, Qint3, qlassf from qlasskit.algorithms import Grover @@ -19,15 +19,12 @@ def qiskit_simulate(qc): @qlassf -def subset_sum(ii: Tuple[Qint2, Qint2]) -> Qint2: - l = [0, 1, 2, 0] - ai, bi = ii - a = l[ai] - b = l[bi] - return a + b +def subset_sum(ii: Tuple[Qint2, Qint2]) -> Qint3: + l = [0, 5, 2, 3] + return l[ii[0]] + l[ii[1]] if ii[0] != ii[1] else 0 -algo = Grover(subset_sum, Qint2(3)) +algo = Grover(subset_sum, Qint3(7)) qc = algo.circuit().export("circuit", "qiskit") print(qc.draw("text")) diff --git a/qlasskit/__init__.py b/qlasskit/__init__.py index 350084e4..75de10b5 100644 --- a/qlasskit/__init__.py +++ b/qlasskit/__init__.py @@ -16,7 +16,7 @@ __version__ = "0.0.2" from .qcircuit import QCircuit, SupportedFrameworks, SupportedFramework # noqa: F401 -from .qlassfun import QlassF, qlassf # noqa: F401 +from .qlassfun import QlassF, qlassf, qlassf_a # noqa: F401 from .ast2ast import ast2ast # noqa: F401 from .ast2logic import exceptions # noqa: F401 from .types import ( # noqa: F401, F403 @@ -24,6 +24,7 @@ Qtype, Qint, Qint2, + Qint3, Qint4, Qint8, Qint12, diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py index edfbd79f..ae229c8a 100644 --- a/qlasskit/ast2ast.py +++ b/qlasskit/ast2ast.py @@ -69,6 +69,8 @@ def _replace_types_annotations(ann, arg=None): class ASTRewriter(ast.NodeTransformer): + """Rewrites the ast to a simplified version""" + def __init__(self, env={}, ret=None): self.env = {} self.const = {} @@ -109,10 +111,13 @@ def visit_Subscript(self, node): if isinstance(_sval, ast.Name) and _sval.id in self.const: node.slice = self.const[_sval.id] - + # Unroll L[a] with (L[0] if a == 0 else L[1] if a == 1 ...) - elif isinstance(_sval, ast.Name) and _sval.id not in self.const: + elif (isinstance(_sval, ast.Name) and _sval.id not in self.const) or isinstance(_sval, ast.Subscript): if isinstance(node.value, ast.Name): + if node.value.id == 'Tuple': + return node + tup = self.env[node.value.id] else: tup = node.value diff --git a/qlasskit/qlassfun.py b/qlasskit/qlassfun.py index 077c3e9c..afc753f4 100644 --- a/qlasskit/qlassfun.py +++ b/qlasskit/qlassfun.py @@ -247,5 +247,27 @@ def qlassf( defs_fun = list(map(lambda q: q.to_logicfun(), defs)) return QlassF.from_function( - f, types, defs_fun, to_compile, compiler, uncompute=uncompute + f, + types, + defs_fun, + to_compile, + compiler, + uncompute=uncompute, + bool_optimizer=bool_optimizer, ) + + +def qlassf_a( + types: List[Qtype] = [], + defs: List[QlassF] = [], + to_compile: bool = True, + compiler: SupportedCompiler = "internal", + bool_optimizer: BoolOptimizerProfile = bestWorkingOptimizer, + uncompute: bool = True, +): + """Decorator with parameters for qlassf""" + + def _inner(fun): + return qlassf(fun, types, defs, to_compile, compiler, bool_optimizer, uncompute) + + return _inner diff --git a/qlasskit/types/__init__.py b/qlasskit/types/__init__.py index 62abbc35..2d900d6c 100644 --- a/qlasskit/types/__init__.py +++ b/qlasskit/types/__init__.py @@ -46,14 +46,14 @@ def _full_adder(c, a, b): # Carry x Sum from .qtype import Qtype, TExp, TType # noqa: F401, E402 from .qbool import Qbool # noqa: F401, E402 from .qlist import Qlist # noqa: F401, E402 -from .qint import Qint, Qint2, Qint4, Qint8, Qint12, Qint16 # noqa: F401, E402 +from .qint import Qint, Qint2, Qint3, Qint4, Qint8, Qint12, Qint16 # noqa: F401, E402 -BUILTIN_TYPES = [Qint2, Qint4, Qint8, Qint12, Qint16, Qlist] +BUILTIN_TYPES = [Qint2, Qint3, Qint4, Qint8, Qint12, Qint16, Qlist] def const_to_qtype(value: Any) -> TExp: if isinstance(value, int): - for det_type in [Qint2, Qint4, Qint8, Qint12, Qint16]: + for det_type in [Qint2, Qint3, Qint4, Qint8, Qint12, Qint16]: if value < 2**det_type.BIT_SIZE: return det_type.const(value) diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index f79351ec..2a918fdb 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -248,7 +248,9 @@ def bitwise_or(cls, tleft: TExp, tright: TExp) -> TExp: class Qint2(Qint): BIT_SIZE = 2 - +class Qint3(Qint): + BIT_SIZE = 3 + class Qint4(Qint): BIT_SIZE = 4 diff --git a/test/test_algo_grover.py b/test/test_algo_grover.py index 629a6543..0a1f27d4 100644 --- a/test/test_algo_grover.py +++ b/test/test_algo_grover.py @@ -74,10 +74,7 @@ def test_grover_subset_sum(self): f = """ def subset_sum(ii: Tuple[Qint2, Qint2]) -> Qint2: l = [0, 1, 2, 0] - ai, bi = ii - a = l[ai] - b = l[bi] - return a + b + return l[ii[0]] + l[ii[1]] """ qf = qlassf(f) algo = Grover(qf, Qint2(3)) diff --git a/test/test_qlassf_list.py b/test/test_qlassf_list.py index 9fc49e0f..8ba0afff 100644 --- a/test/test_qlassf_list.py +++ b/test/test_qlassf_list.py @@ -94,3 +94,12 @@ def test_list_access_with_var_on_tuple(self): f = "def test(ab: Tuple[Qint2, Qint2]) -> Qint2:\n\tc = [1,2,3,2]\n\tai,bi = ab\n\td = c[ai] + c[bi]\n\treturn d" qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) + + def test_list_access_with_var_on_tuple2(self): + # TODO: this fails on internal compiler + if self.compiler == "internal": + return + + f = "def test(ab: Tuple[Qint2, Qint2]) -> Qint2:\n\tc = [1,2,3,2]\n\td = c[ab[0]] + c[ab[1]]\n\treturn d" + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf)