Skip to content

Commit

Permalink
handle call constant folding with list and tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Jul 2, 2024
1 parent 23198b4 commit 28952f2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
39 changes: 25 additions & 14 deletions qlasskit/ast2ast/constantfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,20 @@ def visit_BinOp(self, node):
def visit_Call(self, node):
self.generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id in self.builtin_funcs:
if all(isinstance(arg, ast.Constant) for arg in node.args):

def arg_tr(arg):
if isinstance(arg, ast.Tuple) or isinstance(arg, ast.List):
elts = [self.visit(elt) for elt in arg.elts]
if all(isinstance(elt, ast.Constant) for elt in elts):
return ast.Constant(value=[elt.value for elt in elts])

return arg

args = list(map(arg_tr, node.args))

if all(isinstance(arg, ast.Constant) for arg in args):
func = self.builtin_funcs[node.func.id]
args = [arg.value for arg in node.args]
args = [arg.value for arg in args]
return ast.Constant(func(*args)) # type: ignore
return node

Expand Down Expand Up @@ -121,15 +132,15 @@ def visit_IfExp(self, node):
if isinstance(node.test, ast.Constant):
return node.body if node.test.value else node.orelse
return node
def visit_List(self, node):
elts = [self.visit(elt) for elt in node.elts]
if all(isinstance(elt, ast.Constant) for elt in elts):
return ast.Constant(value=[elt.value for elt in elts])
return ast.List(elts=elts, ctx=node.ctx)

def visit_Tuple(self, node):
elts = [self.visit(elt) for elt in node.elts]
if all(isinstance(elt, ast.Constant) for elt in elts):
return ast.Constant(value=tuple(elt.value for elt in elts))
return ast.Tuple(elts=elts, ctx=node.ctx)

# def visit_List(self, node):
# elts = [self.visit(elt) for elt in node.elts]
# if all(isinstance(elt, ast.Constant) for elt in elts):
# return ast.Constant(value=[elt.value for elt in elts])
# return ast.List(elts=elts, ctx=node.ctx)

# def visit_Tuple(self, node):
# elts = [self.visit(elt) for elt in node.elts]
# if all(isinstance(elt, ast.Constant) for elt in elts):
# return ast.Constant(value=tuple(elt.value for elt in elts))
# return ast.Tuple(elts=elts, ctx=node.ctx)
2 changes: 1 addition & 1 deletion test/test_ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def setUp(self):
[
("a + (13 - 12 + 1)", "a + 2"),
# ( "a + 13 - 12 + 1", "a + 2" ),
( "a + len([12])", "a + 1" ),
("a + len([12])", "a + 1"),
("if True: a \nelse: b", "a"),
("a if False else b", "b"),
]
Expand Down

0 comments on commit 28952f2

Please sign in to comment.