diff --git a/pytket/phir/rebasing/rebaser.py b/pytket/phir/rebasing/rebaser.py index 6c86a26..cfc2585 100644 --- a/pytket/phir/rebasing/rebaser.py +++ b/pytket/phir/rebasing/rebaser.py @@ -3,6 +3,7 @@ from pytket.extensions.quantinuum.backends.quantinuum import ( QuantinuumBackend, ) +from pytket.passes import DecomposeBoxes def rebase_to_qtm_machine(circuit: Circuit, qtm_machine: str) -> Circuit: @@ -13,6 +14,10 @@ def rebase_to_qtm_machine(circuit: Circuit, qtm_machine: str) -> Circuit: machine_debug=False, api_handler=qapi_offline, # type: ignore [arg-type] ) + + # Decompose boxes to ensure no problematic phase gates + DecomposeBoxes().apply(circuit) + # Optimization level 0 includes rebasing and little else # see: https://cqcl.github.io/pytket-quantinuum/api/#default-compilation return backend.get_compiled_circuit(circuit, 0) diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 45e7651..66fe92e 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -14,6 +14,7 @@ OpType.SetBits, OpType.ClassicalExpBox, # some classical operations are rolled up into a box OpType.RangePredicate, + OpType.ExplicitPredicate, ] logger = logging.getLogger(__name__) @@ -46,8 +47,15 @@ def shard(self) -> list[Shard]: ------- list of Shards needed to schedule """ - logger.debug("Sharding begins....") - for command in self._circuit.get_commands(): + logger.debug("Sharding beginning") + commands = self._circuit.get_commands() + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("All commands:") + for command in commands: + logger.debug(command) + + for command in commands: self._process_command(command) self._cleanup_remaining_commands() @@ -69,6 +77,10 @@ def _process_command(self, command: Command) -> None: msg = f"OpType {command.op.type} not supported!" raise NotImplementedError(msg) + if self._is_command_global_phase(command): + logger.debug("Ignoring global Phase gate") + return + if self.should_op_create_shard(command.op): logger.debug( f"Building shard for command: {command}", @@ -77,6 +89,12 @@ def _process_command(self, command: Command) -> None: else: self._add_pending_sub_command(command) + def _is_command_global_phase(self, command: Command) -> bool: + return command.op.type == OpType.Phase or ( + command.op.type == OpType.Conditional + and cast(Conditional, command.op).op.type == OpType.Phase + ) + def _build_shard(self, command: Command) -> None: """Builds a shard. @@ -165,10 +183,10 @@ def _add_pending_sub_command(self, command: Command) -> None: Args: command: tket command (operation, bits, etc) """ - key = command.qubits[0] - if key not in self._pending_commands: - self._pending_commands[key] = [] - self._pending_commands[key].append(command) + qubit_key = command.qubits[0] + if qubit_key not in self._pending_commands: + self._pending_commands[qubit_key] = [] + self._pending_commands[qubit_key].append(command) logger.debug( f"Adding pending command {command}", ) diff --git a/tests/test_api.py b/tests/test_api.py index d589f05..f2aba4d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,8 +1,14 @@ +import logging + +import pytest + from pytket.phir.api import pytket_to_phir from pytket.phir.qtm_machine import QtmMachine from .sample_data import QasmFile, get_qasm_as_circuit +logger = logging.getLogger(__name__) + class TestApi: def test_pytket_to_phir_no_machine(self) -> None: @@ -11,9 +17,9 @@ def test_pytket_to_phir_no_machine(self) -> None: assert pytket_to_phir(circuit) - def test_pytket_to_phir_h1_1(self) -> None: + @pytest.mark.parametrize("test_file", list(QasmFile)) + def test_pytket_to_phir_h1_1_all_circuits(self, test_file: QasmFile) -> None: """Standard case.""" - circuit = get_qasm_as_circuit(QasmFile.baby) + circuit = get_qasm_as_circuit(test_file) - # TODO(neal): Make this test more valuable once PHIR is actually returned assert pytket_to_phir(circuit, QtmMachine.H1_1)