Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial register management implementation #176

Merged
merged 16 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions opensquirrel/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ def decompose(self, decomposer: Decomposer):

def map(self, mapper: Mapper) -> None:
"""Generic qubit mapper pass.
Update the register manager's mapping with a given mapper's mapping.
Map the (virtual) qubits of the circuit to the physical qubits of the target hardware.
"""
self.register_manager.mapping = mapper.get_mapping()
from opensquirrel.reindexer import reindex_circuit
reindex_circuit(self, mapper.get_mapping())

def replace(self, gate_generator: Callable[..., Gate], f):
"""Manually replace occurrences of a given gate with a list of gates.
Expand Down
14 changes: 13 additions & 1 deletion opensquirrel/mapper/mapping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Tuple


class Mapping:
Expand All @@ -24,5 +24,17 @@ def __eq__(self, other):
def __getitem__(self, key: int) -> int:
return self.data[key]

def __len__(self) -> int:
return len(self.data)

def size(self) -> int:
return len(self.data)

def items(self) -> List[Tuple[int, int]]:
return self.data.items()

def keys(self) -> List[int]:
return list(self.data.keys())
juanboschero marked this conversation as resolved.
Show resolved Hide resolved

def values(self) -> List[int]:
return list(self.data.values())
4 changes: 2 additions & 2 deletions opensquirrel/reindexer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from opensquirrel.reindexer.qubit_reindexer import _QubitReIndexer, get_reindexed_circuit
from opensquirrel.reindexer.qubit_reindexer import _QubitReindexer, get_reindexed_circuit, reindex_circuit

__all__ = ["_QubitReIndexer", "get_reindexed_circuit"]
__all__ = ["_QubitReindexer", "get_reindexed_circuit", "reindex_circuit"]
38 changes: 28 additions & 10 deletions opensquirrel/reindexer/qubit_reindexer.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
from typing import List

from opensquirrel.circuit import Circuit
from opensquirrel.mapper.mapping import Mapping
from opensquirrel.register_manager import RegisterManager
from opensquirrel.squirrel_ir import (
BlochSphereRotation,
Comment,
ControlledGate,
Gate,
MatrixGate,
Measure,
Qubit,
SquirrelIR,
SquirrelIRVisitor,
)


class _QubitReIndexer(SquirrelIRVisitor):
class _QubitReindexer(SquirrelIRVisitor):
def __init__(self, qubit_indices: List[int]):
self.qubit_indices = qubit_indices

def visit_comment(self, comment: Comment):
return comment

def visit_bloch_sphere_rotation(self, g: BlochSphereRotation):
result = BlochSphereRotation(
return BlochSphereRotation(
qubit=Qubit(self.qubit_indices.index(g.qubit.index)), angle=g.angle, axis=g.axis, phase=g.phase
)
return result

def visit_matrix_gate(self, g: MatrixGate):
mapped_operands = [Qubit(self.qubit_indices.index(op.index)) for op in g.operands]
result = MatrixGate(matrix=g.matrix, operands=mapped_operands)
reindexed_operands = [Qubit(self.qubit_indices.index(op.index)) for op in g.operands]
result = MatrixGate(matrix=g.matrix, operands=reindexed_operands)
return result

def visit_controlled_gate(self, controlled_gate: ControlledGate):
Expand All @@ -34,13 +39,26 @@ def visit_controlled_gate(self, controlled_gate: ControlledGate):
result = ControlledGate(control_qubit=control_qubit, target_gate=target_gate)
return result

def visit_measure(self, measure: Measure):
return Measure(
qubit=Qubit(self.qubit_indices.index(measure.qubit.index)), axis=measure.axis
)


def get_reindexed_circuit(replacement: List[Gate], qubit_indices: List[int]) -> Circuit:
def get_reindexed_circuit(replacement_gates: List[Gate], qubit_indices: List[int]) -> Circuit:
qubit_reindexer = _QubitReindexer(qubit_indices)
register_manager = RegisterManager(qubit_register_size=len(qubit_indices))
replacement_ir = SquirrelIR()
qubit_re_indexer = _QubitReIndexer(qubit_indices)
for gate in replacement:
gate_with_reindexed_qubits = gate.accept(qubit_re_indexer)
for gate in replacement_gates:
gate_with_reindexed_qubits = gate.accept(qubit_reindexer)
replacement_ir.add_gate(gate_with_reindexed_qubits)

return Circuit(register_manager, replacement_ir)


def reindex_circuit(circuit: Circuit, mapping: Mapping) -> None:
qubit_reindexer = _QubitReindexer(mapping.values())
replacement_ir = SquirrelIR()
for statement in circuit.squirrel_ir.statements:
statement = statement.accept(qubit_reindexer)
replacement_ir.statements.append(statement)
circuit.squirrel_ir = replacement_ir
28 changes: 14 additions & 14 deletions opensquirrel/squirrel_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __eq__(self, other):
return self.qubit == other.qubit and np.allclose(self.axis, other.axis, atol=ATOL)

def accept(self, visitor: SquirrelIRVisitor):
juanboschero marked this conversation as resolved.
Show resolved Hide resolved
visitor.visit_measure(self)
return visitor.visit_measure(self)

def get_qubit_operands(self) -> List[Qubit]:
return [self.qubit]
Expand All @@ -132,7 +132,7 @@ def __init__(self, generator, arguments):
def __eq__(self, other):
if not isinstance(other, Gate):
return False
return _compare_gate_classes(self, other)
return compare_gates(self, other)

@property
def name(self) -> Optional[str]:
Expand Down Expand Up @@ -259,18 +259,6 @@ def is_identity(self) -> bool:
return self.target_gate.is_identity()


def _compare_gate_classes(g1: Gate, g2: Gate) -> bool:
union_mapping = [q.index for q in list(set(g1.get_qubit_operands()) | set(g2.get_qubit_operands()))]

from opensquirrel.circuit_matrix_calculator import get_circuit_matrix
from opensquirrel.reindexer import get_reindexed_circuit

matrix_g1 = get_circuit_matrix(get_reindexed_circuit([g1], union_mapping))
matrix_g2 = get_circuit_matrix(get_reindexed_circuit([g2], union_mapping))

return are_matrices_equivalent_up_to_global_phase(matrix_g1, matrix_g2)


def named_gate(gate_generator: Callable[..., Gate]) -> Callable[..., Gate]:
@wraps(gate_generator)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -319,6 +307,18 @@ def wrapper(*args, **kwargs):
return wrapper


def compare_gates(g1: Gate, g2: Gate) -> bool:
union_mapping = [q.index for q in list(set(g1.get_qubit_operands()) | set(g2.get_qubit_operands()))]

from opensquirrel.circuit_matrix_calculator import get_circuit_matrix
from opensquirrel.reindexer import get_reindexed_circuit

matrix_g1 = get_circuit_matrix(get_reindexed_circuit([g1], union_mapping))
matrix_g2 = get_circuit_matrix(get_reindexed_circuit([g2], union_mapping))

return are_matrices_equivalent_up_to_global_phase(matrix_g1, matrix_g2)


@dataclass
class Comment(Statement):
str: str
Expand Down
7 changes: 3 additions & 4 deletions test/mapper/test_general_mapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import List

import pytest

Expand Down Expand Up @@ -42,7 +43,7 @@ def circuit_fixture(self) -> Circuit:
return Circuit(register_manager, squirrel_ir)

@pytest.fixture(name="expected_statements")
def expected_statements_fixture(self) -> list[Statement]:
def expected_statements_fixture(self) -> List[Statement]:
return [
H(Qubit(1)),
CNOT(Qubit(1), Qubit(0)),
Expand All @@ -54,6 +55,4 @@ def expected_statements_fixture(self) -> list[Statement]:
def test_circuit_map(self, circuit: Circuit, expected_statements: list[Statement]) -> None:
mapper = HardcodedMapper(circuit.qubit_register_size, Mapping([1, 0, 2]))
circuit.map(mapper)

# Check that the circuit is altered as expected
assert circuit.register_manager.mapping == mapper.get_mapping()
assert circuit.squirrel_ir.statements == expected_statements
Loading