Skip to content

Commit

Permalink
update registerhandling to use i1* pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
cqc-melf committed Aug 7, 2023
1 parent ce2b64e commit ca50e79
Show file tree
Hide file tree
Showing 41 changed files with 998 additions and 625 deletions.
4 changes: 3 additions & 1 deletion pytket/qir/conversion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,17 @@ def keep_line(line: str) -> bool:
return (
("@__quantum__qis__read_result__body" not in line)
and ("@set_one_bit_in_reg" not in line)
and ("@reg2var" not in line)
and ("@read_bit_from_reg" not in line)
and ("@set_all_bits_in_reg" not in line)
and ("@read_all_bits_from_reg" not in line)
and ("@create_reg" not in line)
)

result = "\n".join(filter(keep_line, initial_result.split("\n")))

# replace the use of the removed register variable with i64 0
result = result.replace("i64 %0", "i64 0")
result = result.replace("i64 %3", "i64 0")

for _ in range(10):
result = result.replace("\n\n\n\n", "\n\n")
Expand Down
103 changes: 65 additions & 38 deletions pytket/qir/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
self.module = module
self.wasm_int_type = pyqir.IntType(self.module.context, wasm_int_type)
self.qir_int_type = pyqir.IntType(self.module.context, qir_int_type)
self.qir_i1p_type = pyqir.PointerType(pyqir.IntType(self.module.context, 1))
self.qir_bool_type = pyqir.IntType(self.module.context, 1)
self.qubit_type = pyqir.qubit_type(self.module.context)
self.result_type = pyqir.result_type(self.module.context)
Expand All @@ -132,35 +133,35 @@ def __init__(
self.set_cregs: Dict[str, List] = {} # Keep track of set registers.
self.ssa_vars: Dict[str, Value] = {} # Keep track of set ssa variables.

# i1 read_bit_from_reg(i64 reg, i64 index)
# i1 read_bit_from_reg(i1* reg, i64 index)
self.read_bit_from_reg = self.module.module.add_external_function(
"read_bit_from_reg",
pyqir.FunctionType(
pyqir.IntType(self.module.module.context, 1),
[self.qir_int_type] * 2,
[self.qir_i1p_type, self.qir_int_type],
),
)

# void set_one_bit_in_reg(i64 reg, i64 index, i1 value)
# void set_one_bit_in_reg(i1* reg, i64 index, i1 value)
self.set_one_bit_in_reg = self.module.module.add_external_function(
"set_one_bit_in_reg",
pyqir.FunctionType(
pyqir.Type.void(self.module.module.context),
[
self.qir_int_type,
self.qir_i1p_type,
self.qir_int_type,
pyqir.IntType(self.module.module.context, 1),
],
),
)

# void set_all_bits_in_reg(i64 reg, i64 value)
# void set_all_bits_in_reg(i1* reg, i64 value)
self.set_all_bits_in_reg = self.module.module.add_external_function(
"set_all_bits_in_reg",
pyqir.FunctionType(
pyqir.Type.void(self.module.module.context),
[
self.qir_int_type,
self.qir_i1p_type,
self.qir_int_type,
],
),
Expand All @@ -175,12 +176,23 @@ def __init__(
),
)

# i64 reg2var(i1, i1, i1, ...)
self.reg2var = self.module.module.add_external_function(
"reg2var",
# i1* create_reg(i64 size)
self.create_reg = self.module.module.add_external_function(
"create_reg",
pyqir.FunctionType(
self.qir_i1p_type,
[pyqir.IntType(self.module.module.context, qir_int_type)],
),
)

# i64 create_reg(i1* reg)
self.read_all_bits_from_reg = self.module.module.add_external_function(
"read_all_bits_from_reg",
pyqir.FunctionType(
pyqir.IntType(self.module.module.context, qir_int_type),
[pyqir.IntType(self.module.module.context, 1)] * qir_int_type,
self.qir_int_type,
[
self.qir_i1p_type,
],
),
)

Expand Down Expand Up @@ -377,6 +389,13 @@ def _get_optype_and_params(self, op: Op) -> Tuple[OpType, Sequence[float]]:
params = op.params
return (optype, params)

def _get_i64_ssa_reg(self, name: str) -> Value:
ssa_var = self.module.builder.call(
self.read_all_bits_from_reg,
[self.ssa_vars[name]],
)
return ssa_var

def _to_qis_qubits(self, qubits: List[Qubit]) -> Sequence[Qubit]:
return [self.module.module.qubits[qubit.index[0]] for qubit in qubits]

Expand Down Expand Up @@ -405,11 +424,11 @@ def _reg2ssa_var(self, bit_reg: BitRegister, int_size: int) -> Value:
return pyqir.const(self.qir_int_type, value)
else:
bit_reg = [False] * len(bit_reg)
if (size := len(bit_reg)) <= int_size: # Widening by zero-padding.
bool_reg = bit_reg + [False] * (int_size - size)
else: # Narrowing by truncation.
bool_reg = bit_reg[:int_size]
ssa_var = cast(Value, self.module.builder.call(self.reg2var, [*bool_reg])) # type: ignore
if len(bit_reg) > int_size:
raise ValueError(
f"Classical register should only have the size of {int_size}"
)
ssa_var = cast(Value, self.module.builder.call(self.create_reg, [pyqir.const(self.qir_int_type, len(bit_reg))])) # type: ignore
self.ssa_vars[reg_name] = ssa_var
return ssa_var
else:
Expand Down Expand Up @@ -453,7 +472,7 @@ def _get_ssa_from_cl_reg_op(
)
return output_instruction # type: ignore
elif type(reg) == BitRegister:
return self.ssa_vars[reg.name]
return self._get_i64_ssa_reg(reg.name)
elif type(reg) == int:
return pyqir.const(self.qir_int_type, reg)
else:
Expand Down Expand Up @@ -506,7 +525,7 @@ def circuit_to_module(
result = module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.lower),
self.ssa_vars[registername],
self._get_i64_ssa_reg(registername),
)

condition_bit_index = command.args[-1].index[0]
Expand All @@ -529,10 +548,14 @@ def circuit_to_module(
registername = command.args[0].reg_name

lower_cond = module.module.builder.icmp(
pyqir.IntPredicate.SGT, lower_qir, self.ssa_vars[registername]
pyqir.IntPredicate.SGT,
lower_qir,
self._get_i64_ssa_reg(registername),
)
upper_cond = module.module.builder.icmp(
pyqir.IntPredicate.SGT, self.ssa_vars[registername], upper_qir
pyqir.IntPredicate.SGT,
self._get_i64_ssa_reg(registername),
upper_qir,
)

result = module.module.builder.and_(lower_cond, upper_cond)
Expand Down Expand Up @@ -620,7 +643,7 @@ def condition_block() -> None:
ssabool = module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self.ssa_vars[condition_name],
self._get_i64_ssa_reg(condition_name),
)

module.module.builder.if_(
Expand Down Expand Up @@ -809,18 +832,18 @@ def condition_block() -> None:
0 # defines the default value for ops that returns bool, see below
)
outputs = command.args[-1].reg_name
ssa_left = self.ssa_vars[list(self.ssa_vars)[0]] # set default value
ssa_right = self.ssa_vars[list(self.ssa_vars)[0]] # set default value

#
ssa_left = (self._get_i64_ssa_reg(list(self.ssa_vars)[0]),)
ssa_right = (self._get_i64_ssa_reg(list(self.ssa_vars)[0]),)

if type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_REG:
# classical ops acting on registers returning register
ssa_left = self._get_ssa_from_cl_reg_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_reg_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -830,11 +853,13 @@ def condition_block() -> None:

elif type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_BIT:
# classical ops acting on bits returning bit
ssa_left = self._get_ssa_from_cl_bit_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_bit_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_bit_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_bit_op(op.get_exp().args[1], module),
)

# add function to module
Expand All @@ -846,11 +871,13 @@ def condition_block() -> None:

elif type(op.get_exp()) in _TK_CLOPS_TO_PYQIR_REG_BOOL:
# classical ops acting on registers returning bit
ssa_left = self._get_ssa_from_cl_reg_op(
op.get_exp().args[0], module
ssa_left = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[0], module),
)
ssa_right = self._get_ssa_from_cl_reg_op(
op.get_exp().args[1], module
ssa_right = cast( # type: ignore
Value,
self._get_ssa_from_cl_reg_op(op.get_exp().args[1], module),
)

# add function to module
Expand Down Expand Up @@ -984,7 +1011,7 @@ def condition_block() -> None:
self.module.builder.call(
self.record_output_i64,
[
self.ssa_vars[reg_name],
self._get_i64_ssa_reg(reg_name),
self.reg_const[reg_name],
],
)
Expand Down
10 changes: 0 additions & 10 deletions tests/conditional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,8 @@ def test_pytket_qir_conditional_v() -> None:
circ, name="test_pytket_qir_conditional_v", qir_format=QIRFormat.STRING
)

circ = Circuit(2, 2).H(0).H(1).measure_all()

circ.add_gate(OpType.H, [0], condition_bits=[0, 1], condition_value=3)

result_2 = pytket_to_qir(
circ, name="test_pytket_qir_conditional_v", qir_format=QIRFormat.STRING
)

check_qir_result(result, "test_pytket_qir_conditional_v")

check_qir_result(result_2, "test_pytket_qir_conditional_v")


def test_pytket_qir_conditional_6() -> None:
# test conditional for manual added gates
Expand Down
6 changes: 2 additions & 4 deletions tests/conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,9 @@ def test_pytket_qir_11() -> None:

c.add_c_copyreg(a, b)

result = pytket_to_qir(c, name="test_pytket_qir_10", qir_format=QIRFormat.STRING)
result = pytket_to_qir(c, name="test_pytket_qir_11", qir_format=QIRFormat.STRING)

check_qir_result(
result, "test_pytket_qir_10"
) # should be identical to the testcase above
check_qir_result(result, "test_pytket_qir_11")


def test_pytket_qir_12() -> None:
Expand Down
10 changes: 6 additions & 4 deletions tests/qir/test_pytket_qir.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ entry:
ret void
}

declare i1 @read_bit_from_reg(i64, i64)
declare i1 @read_bit_from_reg(i1*, i64)

declare void @set_one_bit_in_reg(i64, i64, i1)
declare void @set_one_bit_in_reg(i1*, i64, i1)

declare void @set_all_bits_in_reg(i64, i64)
declare void @set_all_bits_in_reg(i1*, i64)

declare i1 @__quantum__qis__read_result__body(%Result*)

declare i64 @reg2var(i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1)
declare i1* @create_reg(i64)

declare i64 @read_all_bits_from_reg(i1*)

declare void @__quantum__rt__int_record_output(i64, i8*)

Expand Down
28 changes: 16 additions & 12 deletions tests/qir/test_pytket_qir_10.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,32 @@ source_filename = "test_pytket_qir_10"

define void @main() #0 {
entry:
%0 = call i64 @reg2var(i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false)
%1 = call i64 @reg2var(i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false)
%2 = call i1 @read_bit_from_reg(i64 %0, i64 0)
call void @set_one_bit_in_reg(i64 %1, i64 0, i1 %2)
%3 = call i1 @read_bit_from_reg(i64 %0, i64 1)
call void @set_one_bit_in_reg(i64 %1, i64 1, i1 %3)
%0 = call i1* @create_reg(i64 4)
%1 = call i1* @create_reg(i64 2)
%2 = call i1 @read_bit_from_reg(i1* %0, i64 0)
call void @set_one_bit_in_reg(i1* %1, i64 0, i1 %2)
%3 = call i1 @read_bit_from_reg(i1* %0, i64 1)
call void @set_one_bit_in_reg(i1* %1, i64 1, i1 %3)
call void @__quantum__rt__tuple_start_record_output()
call void @__quantum__rt__int_record_output(i64 %0, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @0, i32 0, i32 0))
call void @__quantum__rt__int_record_output(i64 %1, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
%4 = call i64 @read_all_bits_from_reg(i1* %0)
call void @__quantum__rt__int_record_output(i64 %4, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @0, i32 0, i32 0))
%5 = call i64 @read_all_bits_from_reg(i1* %1)
call void @__quantum__rt__int_record_output(i64 %5, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
call void @__quantum__rt__tuple_end_record_output()
ret void
}

declare i1 @read_bit_from_reg(i64, i64)
declare i1 @read_bit_from_reg(i1*, i64)

declare void @set_one_bit_in_reg(i64, i64, i1)
declare void @set_one_bit_in_reg(i1*, i64, i1)

declare void @set_all_bits_in_reg(i64, i64)
declare void @set_all_bits_in_reg(i1*, i64)

declare i1 @__quantum__qis__read_result__body(%Result*)

declare i64 @reg2var(i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1)
declare i1* @create_reg(i64)

declare i64 @read_all_bits_from_reg(i1*)

declare void @__quantum__rt__int_record_output(i64, i8*)

Expand Down
51 changes: 51 additions & 0 deletions tests/qir/test_pytket_qir_11.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
; ModuleID = 'test_pytket_qir_11'
source_filename = "test_pytket_qir_11"

%Result = type opaque

@0 = internal constant [2 x i8] c"a\00"
@1 = internal constant [2 x i8] c"b\00"

define void @main() #0 {
entry:
%0 = call i1* @create_reg(i64 2)
%1 = call i1* @create_reg(i64 4)
%2 = call i1 @read_bit_from_reg(i1* %0, i64 0)
call void @set_one_bit_in_reg(i1* %1, i64 0, i1 %2)
%3 = call i1 @read_bit_from_reg(i1* %0, i64 1)
call void @set_one_bit_in_reg(i1* %1, i64 1, i1 %3)
call void @__quantum__rt__tuple_start_record_output()
%4 = call i64 @read_all_bits_from_reg(i1* %0)
call void @__quantum__rt__int_record_output(i64 %4, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @0, i32 0, i32 0))
%5 = call i64 @read_all_bits_from_reg(i1* %1)
call void @__quantum__rt__int_record_output(i64 %5, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
call void @__quantum__rt__tuple_end_record_output()
ret void
}

declare i1 @read_bit_from_reg(i1*, i64)

declare void @set_one_bit_in_reg(i1*, i64, i1)

declare void @set_all_bits_in_reg(i1*, i64)

declare i1 @__quantum__qis__read_result__body(%Result*)

declare i1* @create_reg(i64)

declare i64 @read_all_bits_from_reg(i1*)

declare void @__quantum__rt__int_record_output(i64, i8*)

declare void @__quantum__rt__tuple_start_record_output()

declare void @__quantum__rt__tuple_end_record_output()

attributes #0 = { "entry_point" "num_required_qubits"="1" "num_required_results"="1" "output_labeling_schema" "qir_profiles"="custom" }

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
Loading

0 comments on commit ca50e79

Please sign in to comment.