diff --git a/qlasskit/ast2ast.py b/qlasskit/ast2ast.py new file mode 100644 index 00000000..9da4f3ae --- /dev/null +++ b/qlasskit/ast2ast.py @@ -0,0 +1,60 @@ +# Copyright 2023 Davide Gessa + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ast + + +class ASTRewriter(ast.NodeTransformer): + def visit_Call(self, node): + if not hasattr(node.func, "id"): + return node + + if node.func.id == "print": + return None + + elif node.func.id in ["min", "max"]: + if len(node.args) == 1: + if isinstance(node.args[0], ast.Tuple): + args = node.args[0].elts + else: + # TODO: not handled the case when the arg is a tuple; + # we can infer the type, in some way + return node.args[0] + else: + args = node.args + + op = ast.Gt() if node.func.id == "max" else ast.Lt() + + def iterif(arg_l): + if len(arg_l) == 1: + return arg_l[0] + else: + comps = [ + ast.Compare(left=arg_l[0], ops=[op], comparators=[l_it]) + for l_it in arg_l[1:] + ] + comp = ast.BoolOp(op=ast.And(), values=comps) + return ast.IfExp( + test=comp, body=arg_l[0], orelse=iterif(arg_l[1:]) + ) + + return iterif(args) + + else: + return node + + +def ast2ast(a_tree): + new_ast = ASTRewriter().visit(a_tree) + # print(ast.dump(new_ast)) + return new_ast diff --git a/qlasskit/ast2logic/env.py b/qlasskit/ast2logic/env.py index c73b3a1f..a306ce25 100644 --- a/qlasskit/ast2logic/env.py +++ b/qlasskit/ast2logic/env.py @@ -25,12 +25,12 @@ class Env: - def __init__(self): + def __init__(self) -> None: self.bindings: List[Binding] = [] self.types: List[TypeBinding] = [] for t in BUILTIN_TYPES: - self.bind_type((t.__name__, t)) + self.bind_type((t.__name__, t)) # type: ignore def bind_type(self, bb: TypeBinding): if self.know_type(bb[0]): diff --git a/qlasskit/ast2logic/t_expression.py b/qlasskit/ast2logic/t_expression.py index 33ae2a6d..aba70c35 100644 --- a/qlasskit/ast2logic/t_expression.py +++ b/qlasskit/ast2logic/t_expression.py @@ -203,22 +203,15 @@ def unfold(v_exps, op): # Call elif isinstance(expr, ast.Call): - if expr.func.id == "print": # type: ignore - return (None, None) - elif expr.func.id == "len" and len(expr.args) == 1: # type: ignore + if not hasattr(expr.func, "id"): + raise exceptions.ExpressionNotHandledException(expr) + + # This can be moved to ast2ast + if expr.func.id == "len" and len(expr.args) == 1: targ = translate_expression(expr.args[0], env) if isinstance(targ[1], List): return const_to_qtype(len(targ[1])) - - # elif expr.func.id in ["max", "min"]: # type: ignore - # targs = list(map(lambda e: translate_expression(e, env), expr.args)) - - # if len(targs) == 1 and isinstance(targs[0][1], List): - # targs = targs[0] - - # print (len(targs), targs) - - # # TODO: DOING + # print(ast.dump(expr)) raise exceptions.ExpressionNotHandledException(expr) diff --git a/qlasskit/ast2logic/t_statement.py b/qlasskit/ast2logic/t_statement.py index 1bc82067..a251a08f 100644 --- a/qlasskit/ast2logic/t_statement.py +++ b/qlasskit/ast2logic/t_statement.py @@ -80,7 +80,8 @@ def translate_statement( # noqa: C901 # Match elif isinstance(stmt, ast.Expr): - texp, vexp = translate_expression(stmt.value, env) + if hasattr(stmt, "value"): + texp, vexp = translate_expression(stmt.value, env) return [], env else: diff --git a/qlasskit/qlassf.py b/qlasskit/qlassf.py index 6d161f35..a627f92f 100644 --- a/qlasskit/qlassf.py +++ b/qlasskit/qlassf.py @@ -18,6 +18,7 @@ from typing import Callable, List, Tuple, Union # noqa: F401 from . import compiler +from .ast2ast import ast2ast from .ast2logic import Args, BoolExpList, flatten, translate_ast from .types import * # noqa: F403, F401 from .types import Qtype @@ -160,7 +161,7 @@ def from_function( exec(f) fun_ast = ast.parse(f if isinstance(f, str) else inspect.getsource(f)) - fun = fun_ast.body[0] + fun = ast2ast(fun_ast.body[0]) fun_name, args, fun_ret, exps = translate_ast(fun, types) original_f = eval(fun_name) if isinstance(f, str) else f diff --git a/test/test_qlassf_builtin.py b/test/test_qlassf_builtin.py index 7784d77a..dccb3327 100644 --- a/test/test_qlassf_builtin.py +++ b/test/test_qlassf_builtin.py @@ -46,12 +46,49 @@ def test_len4(self): self.assertEqual(qf.expressions[3][1], False) compute_and_compare_results(self, qf) + def test_min(self): + f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn min(a,b)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) + + def test_min_const(self): + f = "def test(a: Qint2) -> Qint2:\n\treturn min(a,3)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) + + def test_max(self): + f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max(a,b)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) + + def test_max_of3(self): + f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max(a,b,3)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) + + def test_max_const(self): + f = "def test(a: Qint2) -> Qint2:\n\treturn max(a,3)" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) - # def test_max(self): - # f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max(a,b)" + # TODO: fixed by cast + # def test_max_const2(self): + # f = "def test(a: Qint4) -> Qint2:\n\treturn max(a,3)" # qf = qlassf(f, to_compile=COMPILATION_ENABLED) # compute_and_compare_results(self, qf) - + + # TODO: not handled + # def test_max_tuple(self): + # f = "def test(a: Tuple[Qint2, Qint2]) -> Qint2:\n\treturn max(a)" + # qf = qlassf(f, to_compile=COMPILATION_ENABLED) + # compute_and_compare_results(self, qf) + + def test_max_tuple_const(self): + f = "def test(a: Qint2, b: Qint2) -> Qint2:\n\treturn max((a, b))" + qf = qlassf(f, to_compile=COMPILATION_ENABLED) + compute_and_compare_results(self, qf) + + # TODO: not handled # def test_max_tuple(self): # f = "def test(a: Tuple[Qint2, Qint2]) -> Qint2:\n\treturn max(a)" # qf = qlassf(f, to_compile=COMPILATION_ENABLED)