diff --git a/qlasskit/ast2ast/constantfolder.py b/qlasskit/ast2ast/constantfolder.py index 6ab330bf..fdc6ea79 100644 --- a/qlasskit/ast2ast/constantfolder.py +++ b/qlasskit/ast2ast/constantfolder.py @@ -91,9 +91,20 @@ def visit_BinOp(self, node): def visit_Call(self, node): self.generic_visit(node) if isinstance(node.func, ast.Name) and node.func.id in self.builtin_funcs: - if all(isinstance(arg, ast.Constant) for arg in node.args): + + def arg_tr(arg): + if isinstance(arg, ast.Tuple) or isinstance(arg, ast.List): + elts = [self.visit(elt) for elt in arg.elts] + if all(isinstance(elt, ast.Constant) for elt in elts): + return ast.Constant(value=[elt.value for elt in elts]) + + return arg + + args = list(map(arg_tr, node.args)) + + if all(isinstance(arg, ast.Constant) for arg in args): func = self.builtin_funcs[node.func.id] - args = [arg.value for arg in node.args] + args = [arg.value for arg in args] return ast.Constant(func(*args)) # type: ignore return node @@ -121,15 +132,15 @@ def visit_IfExp(self, node): if isinstance(node.test, ast.Constant): return node.body if node.test.value else node.orelse return node - - def visit_List(self, node): - elts = [self.visit(elt) for elt in node.elts] - if all(isinstance(elt, ast.Constant) for elt in elts): - return ast.Constant(value=[elt.value for elt in elts]) - return ast.List(elts=elts, ctx=node.ctx) - - def visit_Tuple(self, node): - elts = [self.visit(elt) for elt in node.elts] - if all(isinstance(elt, ast.Constant) for elt in elts): - return ast.Constant(value=tuple(elt.value for elt in elts)) - return ast.Tuple(elts=elts, ctx=node.ctx) \ No newline at end of file + + # def visit_List(self, node): + # elts = [self.visit(elt) for elt in node.elts] + # if all(isinstance(elt, ast.Constant) for elt in elts): + # return ast.Constant(value=[elt.value for elt in elts]) + # return ast.List(elts=elts, ctx=node.ctx) + + # def visit_Tuple(self, node): + # elts = [self.visit(elt) for elt in node.elts] + # if all(isinstance(elt, ast.Constant) for elt in elts): + # return ast.Constant(value=tuple(elt.value for elt in elts)) + # return ast.Tuple(elts=elts, ctx=node.ctx) diff --git a/test/test_ast2ast.py b/test/test_ast2ast.py index 6c17c7ec..07935721 100644 --- a/test/test_ast2ast.py +++ b/test/test_ast2ast.py @@ -63,7 +63,7 @@ def setUp(self): [ ("a + (13 - 12 + 1)", "a + 2"), # ( "a + 13 - 12 + 1", "a + 2" ), - ( "a + len([12])", "a + 1" ), + ("a + len([12])", "a + 1"), ("if True: a \nelse: b", "a"), ("a if False else b", "b"), ]