Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update registerhandling to use i1* pointer #61

Merged
merged 11 commits into from
Aug 9, 2023
2 changes: 1 addition & 1 deletion _metadata.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__extension_version__ = "0.2.0rc14"
__extension_version__ = "0.2.0rc15"
__extension_name__ = "pytket-qir"
5 changes: 5 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
~~~~~~~~~

0.2.0rc15 (August 2023)
-----------------------
* update the classical register handling to use i1* pointer
* update pytket requirement to 1.18

0.2.0rc14 (July 2023)
---------------------
* add simplification for RangePredicate in case of equal bounds
Expand Down
10 changes: 6 additions & 4 deletions pytket/qir/conversion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,18 @@ def pytket_to_qir(
def keep_line(line: str) -> bool:
return (
("@__quantum__qis__read_result__body" not in line)
and ("@set_one_bit_in_reg" not in line)
and ("@reg2var" not in line)
and ("@read_bit_from_reg" not in line)
and ("@set_all_bits_in_reg" not in line)
and ("@set_creg_bit" not in line)
and ("@get_creg_bit" not in line)
and ("@set_creg_to_int" not in line)
and ("@get_int_from_creg" not in line)
and ("@create_creg" not in line)
)

result = "\n".join(filter(keep_line, initial_result.split("\n")))

# replace the use of the removed register variable with i64 0
result = result.replace("i64 %0", "i64 0")
result = result.replace("i64 %3", "i64 0")

for _ in range(10):
result = result.replace("\n\n\n\n", "\n\n")
Expand Down
148 changes: 84 additions & 64 deletions pytket/qir/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
self.module = module
self.wasm_int_type = pyqir.IntType(self.module.context, wasm_int_type)
self.qir_int_type = pyqir.IntType(self.module.context, qir_int_type)
self.qir_i1p_type = pyqir.PointerType(pyqir.IntType(self.module.context, 1))
self.qir_bool_type = pyqir.IntType(self.module.context, 1)
self.qubit_type = pyqir.qubit_type(self.module.context)
self.result_type = pyqir.result_type(self.module.context)
Expand All @@ -129,35 +130,35 @@ def __init__(
self.set_cregs: dict[str, list] = {} # Keep track of set registers.
self.ssa_vars: dict[str, Value] = {} # Keep track of set ssa variables.

# i1 read_bit_from_reg(i64 reg, i64 index)
self.read_bit_from_reg = self.module.module.add_external_function(
"read_bit_from_reg",
# i1 get_creg_bit(i1* creg, i64 index)
self.get_creg_bit = self.module.module.add_external_function(
"get_creg_bit",
pyqir.FunctionType(
pyqir.IntType(self.module.module.context, 1),
[self.qir_int_type] * 2,
[self.qir_i1p_type, self.qir_int_type],
),
)

# void set_one_bit_in_reg(i64 reg, i64 index, i1 value)
self.set_one_bit_in_reg = self.module.module.add_external_function(
"set_one_bit_in_reg",
# void set_creg_bit(i1* creg, i64 index, i1 value)
self.set_creg_bit = self.module.module.add_external_function(
"set_creg_bit",
pyqir.FunctionType(
pyqir.Type.void(self.module.module.context),
[
self.qir_int_type,
self.qir_i1p_type,
self.qir_int_type,
pyqir.IntType(self.module.module.context, 1),
],
),
)

# void set_all_bits_in_reg(i64 reg, i64 value)
self.set_all_bits_in_reg = self.module.module.add_external_function(
"set_all_bits_in_reg",
# void set_creg_to_int(i1* creg, i64 value)
self.set_creg_to_int = self.module.module.add_external_function(
"set_creg_to_int",
pyqir.FunctionType(
pyqir.Type.void(self.module.module.context),
[
self.qir_int_type,
self.qir_i1p_type,
self.qir_int_type,
],
),
Expand All @@ -172,12 +173,23 @@ def __init__(
),
)

# i64 reg2var(i1, i1, i1, ...)
self.reg2var = self.module.module.add_external_function(
"reg2var",
# i1* create_creg(i64 size)
self.create_creg = self.module.module.add_external_function(
"create_creg",
pyqir.FunctionType(
pyqir.IntType(self.module.module.context, qir_int_type),
[pyqir.IntType(self.module.module.context, 1)] * qir_int_type,
self.qir_i1p_type,
[pyqir.IntType(self.module.module.context, qir_int_type)],
),
)

# i64 get_int_from_creg(i1* creg)
self.get_int_from_creg = self.module.module.add_external_function(
"get_int_from_creg",
pyqir.FunctionType(
self.qir_int_type,
[
self.qir_i1p_type,
],
),
)

Expand Down Expand Up @@ -374,6 +386,13 @@ def _get_optype_and_params(self, op: Op) -> tuple[OpType, Sequence[float]]:
params = op.params
return (optype, params)

def _get_i64_ssa_reg(self, name: str) -> Value:
ssa_var = self.module.builder.call(
self.get_int_from_creg,
[self.ssa_vars[name]],
)
return ssa_var

def _to_qis_qubits(self, qubits: list[Qubit]) -> Sequence[Qubit]:
return [self.module.module.qubits[qubit.index[0]] for qubit in qubits]

Expand All @@ -392,21 +411,14 @@ def _to_qis_bits(self, args: list[Bit]) -> Sequence[Value]:
def _reg2ssa_var(self, bit_reg: BitRegister, int_size: int) -> Value:
"""Convert a BitRegister to an SSA variable using pyqir types."""
reg_name = bit_reg[0].reg_name
if (
reg_name not in self.ssa_vars
): # Check if the register has been previously set.
# Check if the register has been previously set. If not, initialise to 0.
if reg_value := self.set_cregs.get(reg_name):
bit_reg = reg_value
value = sum([n * 2**k for k, n in enumerate(reg_value)])
return pyqir.const(self.qir_int_type, value)
else:
bit_reg = [False] * len(bit_reg)
if (size := len(bit_reg)) <= int_size: # Widening by zero-padding.
bool_reg = bit_reg + [False] * (int_size - size)
else: # Narrowing by truncation.
bool_reg = bit_reg[:int_size]
ssa_var = cast(Value, self.module.builder.call(self.reg2var, [*bool_reg])) # type: ignore # noqa: E501
if reg_name not in self.ssa_vars:
if len(bit_reg) > int_size:
raise ValueError(
f"Classical register should only have the size of {int_size}"
)
ssa_var = self.module.builder.call( # type: ignore
self.create_creg, [pyqir.const(self.qir_int_type, len(bit_reg))]
)
self.ssa_vars[reg_name] = ssa_var
return ssa_var
else:
Expand Down Expand Up @@ -449,7 +461,7 @@ def _get_ssa_from_cl_reg_op(
)
return output_instruction # type: ignore
elif type(reg) == BitRegister:
return self.ssa_vars[reg.name]
return self._get_i64_ssa_reg(reg.name)
elif type(reg) == int:
return pyqir.const(self.qir_int_type, reg)
else:
Expand All @@ -460,7 +472,7 @@ def _get_ssa_from_cl_bit_op(
) -> Value:
if type(bit) == Bit:
result = module.builder.call(
self.read_bit_from_reg,
self.get_creg_bit,
[
self.ssa_vars[bit.reg_name],
pyqir.const(self.qir_int_type, bit.index[0]),
Expand Down Expand Up @@ -499,14 +511,14 @@ def circuit_to_module(
result = module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.lower),
self.ssa_vars[registername],
self._get_i64_ssa_reg(registername),
)

condition_bit_index = command.args[-1].index[0]
result_registername = command.args[-1].reg_name

self.module.builder.call(
self.set_one_bit_in_reg,
self.set_creg_bit,
[
self.ssa_vars[result_registername],
pyqir.const(self.qir_int_type, condition_bit_index),
Expand All @@ -521,10 +533,14 @@ def circuit_to_module(
registername = command.args[0].reg_name

lower_cond = module.module.builder.icmp(
pyqir.IntPredicate.SGT, lower_qir, self.ssa_vars[registername]
pyqir.IntPredicate.SGT,
lower_qir,
self._get_i64_ssa_reg(registername),
)
upper_cond = module.module.builder.icmp(
pyqir.IntPredicate.SGT, self.ssa_vars[registername], upper_qir
pyqir.IntPredicate.SGT,
self._get_i64_ssa_reg(registername),
upper_qir,
)

result = module.module.builder.and_(lower_cond, upper_cond)
Expand All @@ -533,7 +549,7 @@ def circuit_to_module(
registername = command.args[-1].reg_name

self.module.builder.call(
self.set_one_bit_in_reg,
self.set_creg_bit,
[
self.ssa_vars[registername],
pyqir.const(self.qir_int_type, condition_bit_index),
Expand Down Expand Up @@ -569,7 +585,7 @@ def condition_block_false() -> None:
assert condition_name in self.ssa_vars

ssabool = module.builder.call(
self.read_bit_from_reg,
self.get_creg_bit,
[
self.ssa_vars[condition_name],
pyqir.const(self.qir_int_type, condition_bit_index),
Expand Down Expand Up @@ -610,7 +626,7 @@ def condition_block() -> None:
ssabool = module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self.ssa_vars[condition_name],
self._get_i64_ssa_reg(condition_name),
)

module.module.builder.if_(
Expand Down Expand Up @@ -775,7 +791,7 @@ def condition_block() -> None:
)

self.module.builder.call(
self.set_one_bit_in_reg,
self.set_creg_bit,
[
self.ssa_vars[command.bits[0].reg_name],
pyqir.const(self.qir_int_type, command.bits[0].index[0]),
Expand All @@ -793,18 +809,18 @@ def condition_block() -> None:
0 # defines the default value for ops that returns bool, see below
)
outputs = command.args[-1].reg_name
ssa_left = self.ssa_vars[list(self.ssa_vars)[0]] # set default value
ssa_right = self.ssa_vars[list(self.ssa_vars)[0]] # set default value

#
ssa_left = (self._get_i64_ssa_reg(list(self.ssa_vars)[0]),)
ssa_right = (self._get_i64_ssa_reg(list(self.ssa_vars)[0]),)

if type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_REG:
# classical ops acting on registers returning register
ssa_left = self._get_ssa_from_cl_reg_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_reg_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -814,11 +830,13 @@ def condition_block() -> None:

elif type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_BIT:
# classical ops acting on bits returning bit
ssa_left = self._get_ssa_from_cl_bit_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_bit_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_bit_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_bit_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -830,11 +848,13 @@ def condition_block() -> None:

elif type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_REG_BOOL:
# classical ops acting on registers returning bit
ssa_left = self._get_ssa_from_cl_reg_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_reg_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -854,7 +874,7 @@ def condition_block() -> None:
# of the register, this could be changed to a user given value

self.module.builder.call(
self.set_one_bit_in_reg,
self.set_creg_bit,
[
self.ssa_vars[outputs],
pyqir.const(self.qir_int_type, result_index),
Expand All @@ -863,7 +883,7 @@ def condition_block() -> None:
)
else:
self.module.builder.call(
self.set_all_bits_in_reg,
self.set_creg_to_int,
[self.ssa_vars[outputs], output_instruction],
)

Expand All @@ -875,7 +895,7 @@ def condition_block() -> None:
output_instruction = pyqir.const(self.qir_bool_type, int(v))

self.module.builder.call(
self.set_one_bit_in_reg,
self.set_creg_bit,
[
self.ssa_vars[b.reg_name],
pyqir.const(self.qir_int_type, b.index[0]),
Expand All @@ -890,15 +910,15 @@ def condition_block() -> None:

for i, o in zip(command.args[:half_length], command.args[half_length:]):
output_instruction = self.module.builder.call(
self.read_bit_from_reg,
self.get_creg_bit,
[
self.ssa_vars[i.reg_name], # type: ignore
pyqir.const(self.qir_int_type, i.index[0]), # type: ignore
],
)

self.module.builder.call(
self.set_one_bit_in_reg,
self.set_creg_bit,
[
self.ssa_vars[o.reg_name],
pyqir.const(self.qir_int_type, o.index[0]),
Expand Down Expand Up @@ -966,7 +986,7 @@ def condition_block() -> None:
self.module.builder.call(
self.record_output_i64,
[
self.ssa_vars[reg_name],
self._get_i64_ssa_reg(reg_name),
self.reg_const[reg_name],
],
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
packages=find_namespace_packages(include=["pytket.*"]),
include_package_data=True,
install_requires=[
"pytket ~= 1.17",
"pytket ~= 1.18",
"pyqir == 0.8.2",
"pyqir-generator == 0.7.0",
"pyqir-evaluator == 0.7.0",
Expand Down
Loading