Skip to content

Commit

Permalink
add an ast2ast pass, ijmplement min and max
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 18, 2023
1 parent 33c6bef commit 3a9ba01
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 20 deletions.
60 changes: 60 additions & 0 deletions qlasskit/ast2ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2023 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast


class ASTRewriter(ast.NodeTransformer):
def visit_Call(self, node):
if not hasattr(node.func, "id"):
return node

if node.func.id == "print":
return None

elif node.func.id in ["min", "max"]:
if len(node.args) == 1:
if isinstance(node.args[0], ast.Tuple):
args = node.args[0].elts
else:
# TODO: not handled the case when the arg is a tuple;
# we can infer the type, in some way
return node.args[0]
else:
args = node.args

op = ast.Gt() if node.func.id == "max" else ast.Lt()

def iterif(arg_l):
if len(arg_l) == 1:
return arg_l[0]
else:
comps = [
ast.Compare(left=arg_l[0], ops=[op], comparators=[l_it])
for l_it in arg_l[1:]
]
comp = ast.BoolOp(op=ast.And(), values=comps)
return ast.IfExp(
test=comp, body=arg_l[0], orelse=iterif(arg_l[1:])
)

return iterif(args)

else:
return node


def ast2ast(a_tree):
new_ast = ASTRewriter().visit(a_tree)
# print(ast.dump(new_ast))
return new_ast
4 changes: 2 additions & 2 deletions qlasskit/ast2logic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@


class Env:
def __init__(self):
def __init__(self) -> None:
self.bindings: List[Binding] = []
self.types: List[TypeBinding] = []

for t in BUILTIN_TYPES:
self.bind_type((t.__name__, t))
self.bind_type((t.__name__, t)) # type: ignore

def bind_type(self, bb: TypeBinding):
if self.know_type(bb[0]):
Expand Down
19 changes: 6 additions & 13 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,22 +203,15 @@ def unfold(v_exps, op):

# Call
elif isinstance(expr, ast.Call):
if expr.func.id == "print": # type: ignore
return (None, None)
elif expr.func.id == "len" and len(expr.args) == 1: # type: ignore
if not hasattr(expr.func, "id"):
raise exceptions.ExpressionNotHandledException(expr)

# This can be moved to ast2ast
if expr.func.id == "len" and len(expr.args) == 1:
targ = translate_expression(expr.args[0], env)
if isinstance(targ[1], List):
return const_to_qtype(len(targ[1]))

# elif expr.func.id in ["max", "min"]: # type: ignore
# targs = list(map(lambda e: translate_expression(e, env), expr.args))

# if len(targs) == 1 and isinstance(targs[0][1], List):
# targs = targs[0]

# print (len(targs), targs)

# # TODO: DOING


# print(ast.dump(expr))
raise exceptions.ExpressionNotHandledException(expr)
Expand Down
3 changes: 2 additions & 1 deletion qlasskit/ast2logic/t_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def translate_statement( # noqa: C901
# Match

elif isinstance(stmt, ast.Expr):
texp, vexp = translate_expression(stmt.value, env)
if hasattr(stmt, "value"):
texp, vexp = translate_expression(stmt.value, env)
return [], env

else:
Expand Down
3 changes: 2 additions & 1 deletion qlasskit/qlassf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Callable, List, Tuple, Union # noqa: F401

from . import compiler
from .ast2ast import ast2ast
from .ast2logic import Args, BoolExpList, flatten, translate_ast
from .types import * # noqa: F403, F401
from .types import Qtype
Expand Down Expand Up @@ -160,7 +161,7 @@ def from_function(
exec(f)

fun_ast = ast.parse(f if isinstance(f, str) else inspect.getsource(f))
fun = fun_ast.body[0]
fun = ast2ast(fun_ast.body[0])

fun_name, args, fun_ret, exps = translate_ast(fun, types)
original_f = eval(fun_name) if isinstance(f, str) else f
Expand Down
43 changes: 40 additions & 3 deletions test/test_qlassf_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,49 @@ def test_len4(self):
self.assertEqual(qf.expressions[3][1], False)
compute_and_compare_results(self, qf)

def test_min(self):
f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn min(a,b)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_min_const(self):
f = "def test(a: Qint2) -> Qint2:\n\treturn min(a,3)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_max(self):
f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max(a,b)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_max_of3(self):
f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max(a,b,3)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

def test_max_const(self):
f = "def test(a: Qint2) -> Qint2:\n\treturn max(a,3)"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

# def test_max(self):
# f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max(a,b)"
# TODO: fixed by cast
# def test_max_const2(self):
# f = "def test(a: Qint4) -> Qint2:\n\treturn max(a,3)"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# compute_and_compare_results(self, qf)


# TODO: not handled
# def test_max_tuple(self):
# f = "def test(a: Tuple[Qint2, Qint2]) -> Qint2:\n\treturn max(a)"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
# compute_and_compare_results(self, qf)

def test_max_tuple_const(self):
f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max((a, b))"
qf = qlassf(f, to_compile=COMPILATION_ENABLED)
compute_and_compare_results(self, qf)

# TODO: not handled
# def test_max_tuple(self):
# f = "def test(a: Tuple[Qint2, Qint2]) -> Qint2:\n\treturn max(a)"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED)
Expand Down

0 comments on commit 3a9ba01

Please sign in to comment.