Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update registerhandling to use i1* pointer #61

Merged
merged 11 commits into from
Aug 9, 2023
4 changes: 3 additions & 1 deletion pytket/qir/conversion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,17 @@ def keep_line(line: str) -> bool:
return (
("@__quantum__qis__read_result__body" not in line)
and ("@set_one_bit_in_reg" not in line)
and ("@reg2var" not in line)
and ("@read_bit_from_reg" not in line)
and ("@set_all_bits_in_reg" not in line)
and ("@read_all_bits_from_reg" not in line)
and ("@create_reg" not in line)
)

result = "\n".join(filter(keep_line, initial_result.split("\n")))

# replace the use of the removed register variable with i64 0
result = result.replace("i64 %0", "i64 0")
result = result.replace("i64 %3", "i64 0")

for _ in range(10):
result = result.replace("\n\n\n\n", "\n\n")
Expand Down
116 changes: 68 additions & 48 deletions pytket/qir/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
self.module = module
self.wasm_int_type = pyqir.IntType(self.module.context, wasm_int_type)
self.qir_int_type = pyqir.IntType(self.module.context, qir_int_type)
self.qir_i1p_type = pyqir.PointerType(pyqir.IntType(self.module.context, 1))
self.qir_bool_type = pyqir.IntType(self.module.context, 1)
self.qubit_type = pyqir.qubit_type(self.module.context)
self.result_type = pyqir.result_type(self.module.context)
Expand All @@ -129,35 +130,35 @@ def __init__(
self.set_cregs: dict[str, list] = {} # Keep track of set registers.
self.ssa_vars: dict[str, Value] = {} # Keep track of set ssa variables.

# i1 read_bit_from_reg(i64 reg, i64 index)
# i1 read_bit_from_reg(i1* reg, i64 index)
qartik marked this conversation as resolved.
Show resolved Hide resolved
self.read_bit_from_reg = self.module.module.add_external_function(
"read_bit_from_reg",
pyqir.FunctionType(
pyqir.IntType(self.module.module.context, 1),
[self.qir_int_type] * 2,
[self.qir_i1p_type, self.qir_int_type],
),
)

# void set_one_bit_in_reg(i64 reg, i64 index, i1 value)
# void set_one_bit_in_reg(i1* reg, i64 index, i1 value)
qartik marked this conversation as resolved.
Show resolved Hide resolved
self.set_one_bit_in_reg = self.module.module.add_external_function(
"set_one_bit_in_reg",
pyqir.FunctionType(
pyqir.Type.void(self.module.module.context),
[
self.qir_int_type,
self.qir_i1p_type,
self.qir_int_type,
pyqir.IntType(self.module.module.context, 1),
],
),
)

# void set_all_bits_in_reg(i64 reg, i64 value)
# void set_all_bits_in_reg(i1* reg, i64 value)
qartik marked this conversation as resolved.
Show resolved Hide resolved
self.set_all_bits_in_reg = self.module.module.add_external_function(
"set_all_bits_in_reg",
pyqir.FunctionType(
pyqir.Type.void(self.module.module.context),
[
self.qir_int_type,
self.qir_i1p_type,
self.qir_int_type,
],
),
Expand All @@ -172,12 +173,23 @@ def __init__(
),
)

# i64 reg2var(i1, i1, i1, ...)
self.reg2var = self.module.module.add_external_function(
"reg2var",
# i1* create_reg(i64 size)
qartik marked this conversation as resolved.
Show resolved Hide resolved
self.create_reg = self.module.module.add_external_function(
"create_reg",
pyqir.FunctionType(
pyqir.IntType(self.module.module.context, qir_int_type),
[pyqir.IntType(self.module.module.context, 1)] * qir_int_type,
self.qir_i1p_type,
[pyqir.IntType(self.module.module.context, qir_int_type)],
),
)

# i64 create_reg(i1* reg)
qartik marked this conversation as resolved.
Show resolved Hide resolved
self.read_all_bits_from_reg = self.module.module.add_external_function(
"read_all_bits_from_reg",
pyqir.FunctionType(
self.qir_int_type,
[
self.qir_i1p_type,
],
),
)

Expand Down Expand Up @@ -374,6 +386,13 @@ def _get_optype_and_params(self, op: Op) -> tuple[OpType, Sequence[float]]:
params = op.params
return (optype, params)

def _get_i64_ssa_reg(self, name: str) -> Value:
ssa_var = self.module.builder.call(
self.read_all_bits_from_reg,
[self.ssa_vars[name]],
)
return ssa_var

def _to_qis_qubits(self, qubits: list[Qubit]) -> Sequence[Qubit]:
return [self.module.module.qubits[qubit.index[0]] for qubit in qubits]

Expand All @@ -392,21 +411,14 @@ def _to_qis_bits(self, args: list[Bit]) -> Sequence[Value]:
def _reg2ssa_var(self, bit_reg: BitRegister, int_size: int) -> Value:
"""Convert a BitRegister to an SSA variable using pyqir types."""
reg_name = bit_reg[0].reg_name
if (
reg_name not in self.ssa_vars
): # Check if the register has been previously set.
# Check if the register has been previously set. If not, initialise to 0.
if reg_value := self.set_cregs.get(reg_name):
bit_reg = reg_value
value = sum([n * 2**k for k, n in enumerate(reg_value)])
return pyqir.const(self.qir_int_type, value)
else:
bit_reg = [False] * len(bit_reg)
if (size := len(bit_reg)) <= int_size: # Widening by zero-padding.
bool_reg = bit_reg + [False] * (int_size - size)
else: # Narrowing by truncation.
bool_reg = bit_reg[:int_size]
ssa_var = cast(Value, self.module.builder.call(self.reg2var, [*bool_reg])) # type: ignore # noqa: E501
if reg_name not in self.ssa_vars:
if len(bit_reg) > int_size:
raise ValueError(
f"Classical register should only have the size of {int_size}"
)
ssa_var = self.module.builder.call( # type: ignore
self.create_reg, [pyqir.const(self.qir_int_type, len(bit_reg))]
)
self.ssa_vars[reg_name] = ssa_var
return ssa_var
else:
Expand Down Expand Up @@ -449,7 +461,7 @@ def _get_ssa_from_cl_reg_op(
)
return output_instruction # type: ignore
elif type(reg) == BitRegister:
return self.ssa_vars[reg.name]
return self._get_i64_ssa_reg(reg.name)
elif type(reg) == int:
return pyqir.const(self.qir_int_type, reg)
else:
Expand Down Expand Up @@ -499,7 +511,7 @@ def circuit_to_module(
result = module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.lower),
self.ssa_vars[registername],
self._get_i64_ssa_reg(registername),
)

condition_bit_index = command.args[-1].index[0]
Expand All @@ -521,10 +533,14 @@ def circuit_to_module(
registername = command.args[0].reg_name

lower_cond = module.module.builder.icmp(
pyqir.IntPredicate.SGT, lower_qir, self.ssa_vars[registername]
pyqir.IntPredicate.SGT,
lower_qir,
self._get_i64_ssa_reg(registername),
)
upper_cond = module.module.builder.icmp(
pyqir.IntPredicate.SGT, self.ssa_vars[registername], upper_qir
pyqir.IntPredicate.SGT,
self._get_i64_ssa_reg(registername),
upper_qir,
)

result = module.module.builder.and_(lower_cond, upper_cond)
Expand Down Expand Up @@ -610,7 +626,7 @@ def condition_block() -> None:
ssabool = module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self.ssa_vars[condition_name],
self._get_i64_ssa_reg(condition_name),
)

module.module.builder.if_(
Expand Down Expand Up @@ -793,18 +809,18 @@ def condition_block() -> None:
0 # defines the default value for ops that returns bool, see below
)
outputs = command.args[-1].reg_name
ssa_left = self.ssa_vars[list(self.ssa_vars)[0]] # set default value
ssa_right = self.ssa_vars[list(self.ssa_vars)[0]] # set default value

#
ssa_left = (self._get_i64_ssa_reg(list(self.ssa_vars)[0]),)
ssa_right = (self._get_i64_ssa_reg(list(self.ssa_vars)[0]),)

if type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_REG:
# classical ops acting on registers returning register
ssa_left = self._get_ssa_from_cl_reg_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_reg_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -814,11 +830,13 @@ def condition_block() -> None:

elif type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_BIT:
# classical ops acting on bits returning bit
ssa_left = self._get_ssa_from_cl_bit_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_bit_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_bit_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_bit_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -830,11 +848,13 @@ def condition_block() -> None:

elif type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_REG_BOOL:
# classical ops acting on registers returning bit
ssa_left = self._get_ssa_from_cl_reg_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_reg_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[1], module),
)

# add function to module
Expand Down Expand Up @@ -966,7 +986,7 @@ def condition_block() -> None:
self.module.builder.call(
self.record_output_i64,
[
self.ssa_vars[reg_name],
self._get_i64_ssa_reg(reg_name),
self.reg_const[reg_name],
],
)
Expand Down
10 changes: 0 additions & 10 deletions tests/conditional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,8 @@ def test_pytket_qir_conditional_v() -> None:
circ, name="test_pytket_qir_conditional_v", qir_format=QIRFormat.STRING
)

circ = Circuit(2, 2).H(0).H(1).measure_all()

circ.add_gate(OpType.H, [0], condition_bits=[0, 1], condition_value=3)

result_2 = pytket_to_qir(
circ, name="test_pytket_qir_conditional_v", qir_format=QIRFormat.STRING
)

check_qir_result(result, "test_pytket_qir_conditional_v")

check_qir_result(result_2, "test_pytket_qir_conditional_v")


def test_pytket_qir_conditional_6() -> None:
# test conditional for manual added gates
Expand Down
6 changes: 2 additions & 4 deletions tests/conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,9 @@ def test_pytket_qir_11() -> None:

c.add_c_copyreg(a, b)

result = pytket_to_qir(c, name="test_pytket_qir_10", qir_format=QIRFormat.STRING)
result = pytket_to_qir(c, name="test_pytket_qir_11", qir_format=QIRFormat.STRING)

check_qir_result(
result, "test_pytket_qir_10"
) # should be identical to the testcase above
check_qir_result(result, "test_pytket_qir_11")


def test_pytket_qir_12() -> None:
Expand Down
10 changes: 6 additions & 4 deletions tests/qir/test_pytket_qir.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ entry:
ret void
}

declare i1 @read_bit_from_reg(i64, i64)
declare i1 @read_bit_from_reg(i1*, i64)

declare void @set_one_bit_in_reg(i64, i64, i1)
declare void @set_one_bit_in_reg(i1*, i64, i1)

declare void @set_all_bits_in_reg(i64, i64)
declare void @set_all_bits_in_reg(i1*, i64)

declare i1 @__quantum__qis__read_result__body(%Result*)

declare i64 @reg2var(i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1)
declare i1* @create_reg(i64)

declare i64 @read_all_bits_from_reg(i1*)

declare void @__quantum__rt__int_record_output(i64, i8*)

Expand Down
28 changes: 16 additions & 12 deletions tests/qir/test_pytket_qir_10.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,32 @@ source_filename = "test_pytket_qir_10"

define void @main() #0 {
entry:
%0 = call i64 @reg2var(i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false)
%1 = call i64 @reg2var(i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false)
%2 = call i1 @read_bit_from_reg(i64 %0, i64 0)
call void @set_one_bit_in_reg(i64 %1, i64 0, i1 %2)
%3 = call i1 @read_bit_from_reg(i64 %0, i64 1)
call void @set_one_bit_in_reg(i64 %1, i64 1, i1 %3)
%0 = call i1* @create_reg(i64 4)
%1 = call i1* @create_reg(i64 2)
%2 = call i1 @read_bit_from_reg(i1* %0, i64 0)
call void @set_one_bit_in_reg(i1* %1, i64 0, i1 %2)
%3 = call i1 @read_bit_from_reg(i1* %0, i64 1)
call void @set_one_bit_in_reg(i1* %1, i64 1, i1 %3)
call void @__quantum__rt__tuple_start_record_output()
call void @__quantum__rt__int_record_output(i64 %0, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @0, i32 0, i32 0))
call void @__quantum__rt__int_record_output(i64 %1, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
%4 = call i64 @read_all_bits_from_reg(i1* %0)
call void @__quantum__rt__int_record_output(i64 %4, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @0, i32 0, i32 0))
%5 = call i64 @read_all_bits_from_reg(i1* %1)
call void @__quantum__rt__int_record_output(i64 %5, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
call void @__quantum__rt__tuple_end_record_output()
ret void
}

declare i1 @read_bit_from_reg(i64, i64)
declare i1 @read_bit_from_reg(i1*, i64)

declare void @set_one_bit_in_reg(i64, i64, i1)
declare void @set_one_bit_in_reg(i1*, i64, i1)

declare void @set_all_bits_in_reg(i64, i64)
declare void @set_all_bits_in_reg(i1*, i64)

declare i1 @__quantum__qis__read_result__body(%Result*)

declare i64 @reg2var(i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1)
declare i1* @create_reg(i64)

declare i64 @read_all_bits_from_reg(i1*)

declare void @__quantum__rt__int_record_output(i64, i8*)

Expand Down
Loading