diff --git a/CHANGELOG.md b/CHANGELOG.md index 60716fe6..535865d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## [ 0.3.0 ] - [ xxxx-yy-zz ] +### Added + +- `NativeGateValidator` validator pass + ## [ 0.2.0 ] - [ 2025-01-21 ] diff --git a/opensquirrel/circuit.py b/opensquirrel/circuit.py index 65eb118b..f6a920db 100644 --- a/opensquirrel/circuit.py +++ b/opensquirrel/circuit.py @@ -9,8 +9,9 @@ from opensquirrel.ir import IR, Gate from opensquirrel.passes.decomposer import Decomposer from opensquirrel.passes.mapper import Mapper - from opensquirrel.passes.merger.general_merger import Merger - from opensquirrel.passes.router.general_router import Router + from opensquirrel.passes.merger import Merger + from opensquirrel.passes.router import Router + from opensquirrel.passes.validator import Validator from opensquirrel.register_manager import RegisterManager @@ -87,8 +88,12 @@ def qubit_register_name(self) -> str: def bit_register_name(self) -> str: return self.register_manager.get_bit_register_name() + def validate(self, validator: Validator) -> None: + """Generic validator pass. It applies the given validator to the circuit.""" + validator.validate(self.ir) + def route(self, router: Router) -> None: - """Generic router pass. It applies the given Router to the circuit.""" + """Generic router pass. It applies the given router to the circuit.""" router.route(self.ir) def merge(self, merger: Merger) -> None: diff --git a/opensquirrel/ir.py b/opensquirrel/ir.py index 344b8ca0..85f45b82 100644 --- a/opensquirrel/ir.py +++ b/opensquirrel/ir.py @@ -519,9 +519,13 @@ def name(self) -> str: return self.generator.__name__ return "Anonymous gate: " + self.__repr__() + @property + def is_named_gate(self) -> bool: + return not (self.generator is None or self.generator.__name__ is None) + @property def is_anonymous(self) -> bool: - return self.generator is None + return not self.is_named_gate @staticmethod def _check_repeated_qubit_operands(qubits: Sequence[Qubit]) -> bool: diff --git a/opensquirrel/passes/validator/__init__.py b/opensquirrel/passes/validator/__init__.py new file mode 100644 index 00000000..c88cc843 --- /dev/null +++ b/opensquirrel/passes/validator/__init__.py @@ -0,0 +1,6 @@ +"""Init file for the validator passes.""" + +from opensquirrel.passes.validator.general_validator import Validator +from opensquirrel.passes.validator.native_gate_validator import NativeGateValidator + +__all__ = ["NativeGateValidator", "Validator"] diff --git a/opensquirrel/passes/validator/general_validator.py b/opensquirrel/passes/validator/general_validator.py new file mode 100644 index 00000000..fd9da472 --- /dev/null +++ b/opensquirrel/passes/validator/general_validator.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from opensquirrel.ir import IR + + +class Validator(ABC): + @abstractmethod + def validate(self, ir: IR) -> None: + """Base validate method to be implemented by inheriting validator classes.""" + raise NotImplementedError diff --git a/opensquirrel/passes/validator/native_gate_validator.py b/opensquirrel/passes/validator/native_gate_validator.py new file mode 100644 index 00000000..a4bc527a --- /dev/null +++ b/opensquirrel/passes/validator/native_gate_validator.py @@ -0,0 +1,26 @@ +from opensquirrel.ir import IR, Unitary +from opensquirrel.passes.validator import Validator + + +class NativeGateValidator(Validator): + def __init__(self, native_gate_set: list[str]) -> None: + self.native_gate_set = native_gate_set + + def validate(self, ir: IR) -> None: + """ + Check if all unitary gates in the circuit are part of the native gate set. + + Args: + ir (IR): The intermediate representation of the circuit to be checked. + + Raises: + ValueError: If any unitary gate in the circuit is not part of the native gate set. + """ + gates_not_in_native_gate_set = [ + statement.name + for statement in ir.statements + if isinstance(statement, Unitary) and statement.name not in self.native_gate_set + ] + if gates_not_in_native_gate_set: + error_message = f"The following gates are not in the native gate set: {set(gates_not_in_native_gate_set)}" + raise ValueError(error_message) diff --git a/test/test_integration.py b/test/test_integration.py index 71883054..6ffd3adf 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -17,6 +17,7 @@ from opensquirrel.passes.exporter import ExportFormat from opensquirrel.passes.merger import SingleQubitGatesMerger from opensquirrel.passes.router import RoutingChecker +from opensquirrel.passes.validator import NativeGateValidator def test_spin2plus_backend() -> None: @@ -47,8 +48,12 @@ def test_spin2plus_backend() -> None: # Check whether the above algorithm can be mapped to a dummy chip topology connectivity = {"0": [1], "1": [0]} + native_gate_set = ["I", "X90", "mX90", "Y90", "mY90", "Rz", "CZ"] - qc.route(router=RoutingChecker(connectivity=connectivity)) + qc.route(router=RoutingChecker(connectivity)) + + # Decompose 2-qubit gates to a decomposition where the 2-qubit interactions are captured by CNOT gates + qc.decompose(decomposer=CNOTDecomposer()) qc.decompose(decomposer=SWAP2CNOTDecomposer()) @@ -64,6 +69,9 @@ def test_spin2plus_backend() -> None: # Decompose single-qubit gates to spin backend native gates with McKay decomposer qc.decompose(decomposer=McKayDecomposer()) + # Check whether the gates in the circuit match the native gate set of the backend + qc.validate(validator=NativeGateValidator(native_gate_set)) + assert ( str(qc) == """version 3.0 diff --git a/test/validator/test_native_gate_validator.py b/test/validator/test_native_gate_validator.py new file mode 100644 index 00000000..72f4ddeb --- /dev/null +++ b/test/validator/test_native_gate_validator.py @@ -0,0 +1,52 @@ +# Tests for native gate checker pass +import pytest + +from opensquirrel import CircuitBuilder +from opensquirrel.circuit import Circuit +from opensquirrel.passes.validator import NativeGateValidator + + +@pytest.fixture(name="validator") +def validator_fixture() -> NativeGateValidator: + native_gate_set = ["I", "X90", "mX90", "Y90", "mY90", "Rz", "CZ"] + return NativeGateValidator(native_gate_set) + + +@pytest.fixture +def circuit_with_matching_gate_set() -> Circuit: + builder = CircuitBuilder(5) + builder.I(0) + builder.X90(1) + builder.mX90(2) + builder.Y90(3) + builder.mY90(4) + builder.Rz(0, 2) + builder.CZ(1, 2) + return builder.to_circuit() + + +@pytest.fixture +def circuit_with_unmatching_gate_set() -> Circuit: + builder = CircuitBuilder(5) + builder.I(0) + builder.X90(1) + builder.mX90(2) + builder.Y90(3) + builder.mY90(4) + builder.Rz(0, 2) + builder.CZ(1, 2) + builder.H(0) + builder.CNOT(1, 2) + return builder.to_circuit() + + +def test_matching_gates(validator: NativeGateValidator, circuit_with_matching_gate_set: Circuit) -> None: + try: + validator.validate(circuit_with_matching_gate_set.ir) + except ValueError: + pytest.fail("validate() raised ValueError unexpectedly") + + +def test_non_matching_gates(validator: NativeGateValidator, circuit_with_unmatching_gate_set: Circuit) -> None: + with pytest.raises(ValueError, match="The following gates are not in the native gate set:.*"): + validator.validate(circuit_with_unmatching_gate_set.ir)