From 2ec3dc8e295cac38a50f2d08e363c81b3e780ef9 Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Sat, 6 Jul 2024 17:38:03 +0200 Subject: [PATCH] access tuple of tuple by var --- qlasskit/ast2ast/ast2ast.py | 2 + qlasskit/ast2ast/astrewriter.py | 73 +++++++++++++++++++++++++++++- test/qlassf/test_matrix.py | 80 ++++++++++++++++----------------- test/qlassf/test_tuple.py | 17 ++++--- test/test_decopt.py | 17 ++++--- 5 files changed, 131 insertions(+), 58 deletions(-) diff --git a/qlasskit/ast2ast/ast2ast.py b/qlasskit/ast2ast/ast2ast.py index bb023db1..1f964349 100644 --- a/qlasskit/ast2ast/ast2ast.py +++ b/qlasskit/ast2ast/ast2ast.py @@ -35,6 +35,8 @@ def ast2ast(a_tree): if sys.version_info < (3, 9): a_tree = IndexReplacer().visit(a_tree) + # Matrix translator + # Fold constants a_tree = ConstantFolder().visit(a_tree) diff --git a/qlasskit/ast2ast/astrewriter.py b/qlasskit/ast2ast/astrewriter.py index c0d98875..98342de0 100644 --- a/qlasskit/ast2ast/astrewriter.py +++ b/qlasskit/ast2ast/astrewriter.py @@ -154,13 +154,84 @@ def __unroll_arg(self, arg): def generic_visit(self, node): return super().generic_visit(node) - def visit_Subscript(self, node): + def visit_Subscript(self, node): # noqa: C901 _sval = node.slice # Replace L[a] with const a, to L[const] if isinstance(_sval, ast.Name) and _sval.id in self.const: node.slice = self.const[_sval.id] + # Handle inner access L[i][j] + elif ( + isinstance(node, ast.Subscript) + and isinstance(node.value, ast.Subscript) + and isinstance(node.value.value, ast.Name) + and isinstance(node.value.slice, ast.Name) + and isinstance(node.slice, ast.Name) + ): + nname = node.value.value.id + iname = node.value.slice.id + jname = node.slice.id + + def create_if_exp(i, j, max_i, max_j): + if i == max_i and j == max_j: + return ast.Subscript( + value=ast.Subscript( + value=ast.Name(id=nname, ctx=ast.Load()), + slice=ast.Constant(value=i), + ctx=ast.Load(), + ), + slice=ast.Constant(value=j), + ctx=ast.Load(), + ) + else: + next_j = j + 1 if j < max_j else 0 + next_i = i if j < max_j else i + 1 + return ast.IfExp( + test=ast.BoolOp( + op=ast.And(), + values=[ + ast.Compare( + left=ast.Name(id=iname, ctx=ast.Load()), + ops=[ast.Eq()], + comparators=[ast.Constant(value=i)], + ), + ast.Compare( + left=ast.Name(id=jname, ctx=ast.Load()), + ops=[ast.Eq()], + comparators=[ast.Constant(value=j)], + ), + ], + ), + body=ast.Subscript( + value=ast.Subscript( + value=ast.Name(id=nname, ctx=ast.Load()), + slice=ast.Constant(value=i), + ctx=ast.Load(), + ), + slice=ast.Constant(value=j), + ctx=ast.Load(), + ), + orelse=create_if_exp(next_i, next_j, max_i, max_j), + ) + + # Infer i and j sizes from env['a'] + a_type = self.env[nname] + + # self.env[nname] is a constant + if isinstance(a_type, ast.Tuple): + max_i = len(a_type.elts) - 1 + max_j = len(a_type.elts[0].elts) - 1 # type: ignore + # self.env[nname] is a type annotation + else: + outer_tuple = a_type.slice + max_i = len(outer_tuple.elts) - 1 + inner_tuple = outer_tuple.elts + max_j = len(inner_tuple) - 1 + + # Create the IfExp structure + return create_if_exp(0, 0, max_i, max_j) + # Unroll L[a] with (L[0] if a == 0 else L[1] if a == 1 ...) elif (isinstance(_sval, ast.Name) and _sval.id not in self.const) or isinstance( _sval, ast.Subscript diff --git a/test/qlassf/test_matrix.py b/test/qlassf/test_matrix.py index 8b54645f..484e25d1 100644 --- a/test/qlassf/test_matrix.py +++ b/test/qlassf/test_matrix.py @@ -89,44 +89,42 @@ def test_matrix_len(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - # TODO: this raises Not a tuple in ast2ast visit subscript with not constant _sval: Subscript - # (value=Name(id='a', ctx=Load()), slice=Name(id='i', ctx=Load()), ctx=Load()) - # def test_matrix_access2(self): - # f = ( - # "def test(a: Qmatrix[Qint[2], 2, 2]) -> Qint[2]:\n\ti = 1\n" - # "\tj = i + 1\n\treturn a[i][j]" - # ) - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) - # def test_matrix_access3(self): - # f = ( - # "def test(a: Qmatrix[Qint[2], 2, 2], i: Qint[2], j: Qint[2]) -> Qint[2]:\n" - # "\treturn a[i][j]" - # ) - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) - - # def test_matrix_access_with_var(self): - # f = "def test(a: Qint[2]) -> Qint[2]:\n\tc = [[1,2],[3,4]]\n\tb = c[a][a]\n\treturn b" - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) - - # def test_list_access_with_var_on_tuple(self): - # # TODO: this fails on internal compiler - # if self.compiler == "internal": - # return - - # f = ("def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n\tai,bi = ab\n" - # "\td = c[ai] + c[bi]\n\treturn d") - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) - - # def test_list_access_with_var_on_tuple2(self): - # # TODO: this fails on internal compiler - # if self.compiler == "internal": - # return - - # f = ("def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n" - # "\td = c[ab[0]] + c[ab[1]]\n\treturn d") - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) + def test_matrix_access2(self): + f = ( + "def test(a: Qmatrix[Qint[2], 2, 2]) -> Qint[2]:\n\ti = 0\n" + "\tj = i + 1\n\treturn a[i][j]" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_matrix_access3(self): + f = ( + "def test(a: Qmatrix[Qint[2], 2, 2], i: Qint[2], j: Qint[2]) -> Qint[2]:\n" + "\treturn a[i][j]" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_matrix_access_with_var(self): + f = ( + "def test(a: Qint[2]) -> Qint[4]:\n\tc = [[1,2,7,8],[3,4,8,8],[5,6,9,1],[1,2,7,8]]\n" + "\tb = c[a][a]\n\treturn b" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_list_access_with_var_on_tuple(self): + f = ( + "def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n\tai,bi = ab\n" + "\td = c[ai] + c[bi]\n\treturn d" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) + + def test_list_access_with_var_on_tuple2(self): + f = ( + "def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n" + "\td = c[ab[0]] + c[ab[1]]\n\treturn d" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) diff --git a/test/qlassf/test_tuple.py b/test/qlassf/test_tuple.py index 1cdf3927..67980853 100644 --- a/test/qlassf/test_tuple.py +++ b/test/qlassf/test_tuple.py @@ -199,12 +199,11 @@ def test_tuple_iterator_vartuple(self): qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) compute_and_compare_results(self, qf) - # TODO: failing for #63 - # def test_tuple_of_tuple_var_access(self): - # f = ( - # "def test(a: Tuple[Tuple[bool, bool], Tuple[bool, bool]], " - # "i: Qint[2], j: Qint[2]) -> bool:\n" - # "\treturn a[i][j]" - # ) - # qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) - # compute_and_compare_results(self, qf) + def test_tuple_of_tuple_var_access(self): + f = ( + "def test(a: Tuple[Tuple[bool, bool, bool], Tuple[bool, bool, bool], " + "Tuple[bool, bool, bool]], i: Qint[2], j: Qint[2]) -> bool:\n" + "\treturn a[i][j]" + ) + qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler) + compute_and_compare_results(self, qf) diff --git a/test/test_decopt.py b/test/test_decopt.py index 54efb041..0149ce1e 100644 --- a/test/test_decopt.py +++ b/test/test_decopt.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import random import itertools +import random +import unittest from qlasskit import QCircuit, boolopt, qlassf from qlasskit.decompiler import circuit_boolean_optimizer @@ -95,14 +95,17 @@ def test_circuit_boolean_optimizer_random_2(self): qc_n_un = qiskit_unitary(qc_n.export()) self.assertEqual(qc_un, qc_n_un) - def test_circuit_boolean_optimizer_random_x_cx(self): g_simp = 0 - - possib = [(gates.CX, x[0], x[1]) for x in itertools.permutations([0,1,2],r=2)] - possib += [(gates.X, x[0]) for x in itertools.permutations([0,1,2],r=1)] - for i in random.choices(list(itertools.combinations_with_replacement(possib, r=8)), k=32): + possib = [ + (gates.CX, x[0], x[1]) for x in itertools.permutations([0, 1, 2], r=2) + ] + possib += [(gates.X, x[0]) for x in itertools.permutations([0, 1, 2], r=1)] + + for i in random.choices( + list(itertools.combinations_with_replacement(possib, r=8)), k=32 + ): qc = QCircuit(3) for g in i: qc.append(g[0](), g[1:])