Skip to content

Commit

Permalink
allow new type injection on qlassf
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 16, 2023
1 parent f916dc6 commit 20b03a4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
4 changes: 2 additions & 2 deletions qlasskit/ast2logic/t_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def translate_ast(fun) -> LogicFun:
# env contains names visible from the current scope
env = Env()

args: Args = translate_arguments(fun.args.args)
args: Args = translate_arguments(fun.args.args, env)

[env.bind(arg) for arg in args]

if not fun.returns:
raise exceptions.NoReturnTypeException()

ret_ = translate_argument(fun.returns) # TODO: we need to preserve this
ret_ = translate_argument(fun.returns, env) # TODO: we need to preserve this
ret_size = len(ret_)

exps = []
Expand Down
13 changes: 10 additions & 3 deletions qlasskit/qlassf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from . import compiler
from .ast2logic import Args, BoolExpList, flatten, translate_ast
from .types import * # noqa: F403, F401
from .types import Qtype

MAX_TRUTH_TABLE_SIZE = 20

Expand Down Expand Up @@ -151,7 +152,9 @@ def f(self) -> Callable:
return self.original_f

@staticmethod
def from_function(f: Union[str, Callable], to_compile=True) -> "QlassF":
def from_function(
f: Union[str, Callable], types: List[Qtype] = [], to_compile: bool = True
) -> "QlassF":
"""Create a QlassF from a function or a string containing a function"""
if isinstance(f, str):
exec(f)
Expand All @@ -168,10 +171,14 @@ def from_function(f: Union[str, Callable], to_compile=True) -> "QlassF":
return qf


def qlassf(f: Union[str, Callable], to_compile=True) -> QlassF:
def qlassf(
f: Union[str, Callable], types: List[Qtype] = [], to_compile: bool = True
) -> QlassF:
"""Decorator / function creating a QlassF object
Args:
f: String or function
types (List[Qtype], optional): A list of new types to bind
to_compile (bool, optional): Compile the circuit after parsing
"""
return QlassF.from_function(f, to_compile)
return QlassF.from_function(f, types, to_compile)
2 changes: 2 additions & 0 deletions qlasskit/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
from .qbool import Qbool # noqa: F401
from .qint import Qint, Qint2, Qint4, Qint8, Qint12, Qint16 # noqa: F401
from .qtype import Qtype, TExp, TType # noqa: F401

BUILTIN_TYPES = [Qint2, Qint4, Qint8, Qint12, Qint16]
16 changes: 8 additions & 8 deletions test/test_ast2logic_t_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,62 @@ def test_unknown_type(self):
ann_ast = ast.parse(f).body[0].annotation
self.assertRaises(
exceptions.UnknownTypeException,
lambda ann_ast: ast2logic.translate_argument(ann_ast, "a"),
lambda ann_ast: ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a"),
ann_ast,
)

def test_bool(self):
f = "a: bool"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, bool)
self.assertEqual(c.bitvec, ["a"])

def test_qint2(self):
f = "a: Qint2"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Qint2)
self.assertEqual(c.bitvec, ["a.0", "a.1"])

def test_qint4(self):
f = "a: Qint4"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Qint4)
self.assertEqual(c.bitvec, ["a.0", "a.1", "a.2", "a.3"])

def test_tuple(self):
f = "a: Tuple[bool, bool]"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Tuple[bool, bool])
self.assertEqual(c.bitvec, ["a.0", "a.1"])

def test_tuple_of_tuple(self):
f = "a: Tuple[Tuple[bool, bool], bool]"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Tuple[Tuple[bool, bool], bool])
self.assertEqual(c.bitvec, ["a.0.0", "a.0.1", "a.1"])

def test_tuple_of_tuple2(self):
f = "a: Tuple[bool, Tuple[bool, bool]]"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Tuple[bool, Tuple[bool, bool]])
self.assertEqual(c.bitvec, ["a.0", "a.1.0", "a.1.1"])

def test_tuple_of_int2(self):
f = "a: Tuple[Qint2, Qint2]"
ann_ast = ast.parse(f).body[0].annotation
c = ast2logic.translate_argument(ann_ast, "a")
c = ast2logic.translate_argument(ann_ast, ast2logic.Env(), "a")
self.assertEqual(c.name, "a")
self.assertEqual(c.ttype, Tuple[Qint2, Qint2])
self.assertEqual(
Expand Down

0 comments on commit 20b03a4

Please sign in to comment.