Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ __pycache__/
# C extensions
*.so

# AI
.claude

# IDE
.vscode

Expand Down
94 changes: 60 additions & 34 deletions samplomatic/pre_samplex/pre_samplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,11 +1336,14 @@ def add_combine_node(
register_names: dict[NodeIndex, dict[NodeIndex, RegisterName]],
combined_register_name: str,
combined_register_type: VirtualType,
) -> NodeIndex:
) -> tuple[NodeIndex, RegisterName]:
"""Add a node that combines all the predecessor nodes of a given pre-node.

This function adds a :class:`~.SliceRegisterNode` if the given pre-node has a single
predecessor, or a :class:`~.CombineRegistersNode` if it has multiple predecessors.
If the pre-node has a single predecessor and the slice is trivial (same type, identity
index mapping, no forced copy), the slice node is skipped entirely and the predecessor's
node index and register name are returned directly.

Args:
samplex: The samplex to add nodes to.
Expand All @@ -1356,7 +1359,7 @@ def add_combine_node(
combined_register_type: The type of register to combine the predecessor registers into.

Returns:
A tuple containing the combine node's index and the new register name.
A tuple containing the node index and register name to use for downstream nodes.
"""
pred_idxs = self.sorted_predecessor_idxs(pre_node_idx, order)
pre_edges = [self.graph.get_edge_data(pred_idx, pre_node_idx) for pred_idx in pred_idxs]
Expand All @@ -1379,6 +1382,18 @@ def add_combine_node(
input_register_name, (source_idxs, destination_idxs, input_type) = next(
iter(operands.items())
)

# Skip trivial slices: same type, identity index mapping, no forced copy,
# and the predecessor's register has the same number of subsystems.
if (
input_type == combined_register_type
and not pre_edge.force_register_copy
and np.array_equal(source_idxs, destination_idxs)
and len(destination_idxs) == len(subsystems)
and len(source_idxs) == len(self.graph[pred_idxs[0]].subsystems)
):
return pre_nodes_to_nodes[pred_idxs[0]], input_register_name

slice_idxs = np.empty(len(destination_idxs))
slice_idxs[destination_idxs] = source_idxs
combine_node = SliceRegisterNode(
Expand All @@ -1400,7 +1415,7 @@ def add_combine_node(

for pred_idx in pred_idxs:
samplex.add_edge(pre_nodes_to_nodes[pred_idx], combine_node_idx)
return combine_node_idx
return combine_node_idx, combined_register_name

def add_propagate_node(
self,
Expand Down Expand Up @@ -1434,19 +1449,51 @@ def add_propagate_node(
incoming = set()
for predecessor_idx in self.graph.predecessor_indices(pre_propagate_idx):
incoming.add(samplex.graph[pre_nodes_to_nodes[predecessor_idx]].outgoing_register_type)

# Determine the combined register type.
if mode is InstructionMode.MULTIPLY and pre_propagate.operation.num_qubits == 1:
combined_register_type = VirtualType.U2
elif mode is InstructionMode.PROPAGATE and incoming == {VirtualType.PAULI}:
if (
op_name in PAULI_PAST_CLIFFORD_INVARIANTS
or op_name in PAULI_PAST_CLIFFORD_LOOKUP_TABLES
):
combined_register_type = VirtualType.PAULI
else:
raise SamplexBuildError(
f"Encountered unsupported {op_name} propagation with mode {mode} and "
f"incoming virtual gates {incoming}."
)
else:
raise SamplexBuildError(
f"Encountered unsupported {op_name} propagation with mode {mode} and "
f"incoming virtual gates {incoming}."
)

# Add combine/slice node (may be skipped for trivial slices).
combine_node_idx, actual_register_name = self.add_combine_node(
samplex,
pre_propagate_idx,
pre_nodes_to_nodes,
order,
register_names,
combined_register_name,
combined_register_type,
)

# Create the propagation node using the actual register name.
if mode is InstructionMode.MULTIPLY and pre_propagate.operation.num_qubits == 1:
if pre_propagate.operation.is_parameterized():
param_idxs = [
samplex.append_parameter_expression(param) for _, param in pre_propagate.params
]
if pre_propagate.direction is Direction.LEFT:
propagate_node = RightU2ParametricMultiplicationNode(
op_name, combined_register_name, param_idxs
op_name, actual_register_name, param_idxs
)
else:
propagate_node = LeftU2ParametricMultiplicationNode(
op_name, combined_register_name, param_idxs
op_name, actual_register_name, param_idxs
)
else:
if op_name in SUPPORTED_1Q_FRACTIONAL_GATES:
Expand All @@ -1456,57 +1503,36 @@ def add_propagate_node(
else:
register = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2))
if pre_propagate.direction is Direction.LEFT:
propagate_node = RightMultiplicationNode(register, combined_register_name)
propagate_node = RightMultiplicationNode(register, actual_register_name)
else:
propagate_node = LeftMultiplicationNode(register, combined_register_name)
propagate_node = LeftMultiplicationNode(register, actual_register_name)
elif (
mode is InstructionMode.PROPAGATE
and incoming == {VirtualType.PAULI}
and op_name in PAULI_PAST_CLIFFORD_INVARIANTS
):
# No node is needed since this is an invariant, but we do need to track
# mappings of register names and nodes. This will be done later.
combined_register_type = VirtualType.PAULI
propagate_node = None
elif (
mode is InstructionMode.PROPAGATE
and incoming == {VirtualType.PAULI}
and op_name in PAULI_PAST_CLIFFORD_LOOKUP_TABLES
):
combined_register_type = VirtualType.PAULI
propagate_node = PauliPastCliffordNode(
op_name,
combined_register_name,
actual_register_name,
np.array(list(pre_propagate.partition), dtype=np.intp),
)
else:
raise SamplexBuildError(
f"Encountered unsupported {op_name} propagation with mode {mode} and "
f"incoming virtual gates {incoming}."
)

combine_node_idx = self.add_combine_node(
samplex,
pre_propagate_idx,
pre_nodes_to_nodes,
order,
register_names,
combined_register_name,
combined_register_type,
)

if propagate_node is not None:
node_idx = samplex.add_node(propagate_node)
samplex.add_edge(combine_node_idx, node_idx)
else:
# TODO: It should be possible to not add a slice node in this case, if there is
# a single predecessor.
node_idx = combine_node_idx

pre_nodes_to_nodes[pre_propagate_idx] = node_idx

for pre_successor_idx in self.graph.successor_indices(pre_propagate_idx):
register_names[pre_successor_idx][pre_propagate_idx] = combined_register_name
register_names[pre_successor_idx][pre_propagate_idx] = actual_register_name

def add_collect_node(
self,
Expand All @@ -1531,7 +1557,7 @@ def add_collect_node(
pre_node = cast(PreCollect, self.graph[pre_node_idx])
all_subsystems = pre_node.subsystems
combined_name = f"collect_{order[pre_node_idx]}"
combine_node_idx = self.add_combine_node(
combine_node_idx, actual_register_name = self.add_combine_node(
samplex,
pre_node_idx,
pre_nodes_to_nodes,
Expand All @@ -1544,7 +1570,7 @@ def add_collect_node(
collect = CollectTemplateValues(
"parameter_values",
pre_node.param_idxs,
combined_name,
actual_register_name,
VirtualType.U2,
np.arange(len(all_subsystems)),
pre_node.synth,
Expand Down Expand Up @@ -1574,7 +1600,7 @@ def add_collect_z2_to_output_node(
"""
pre_node = cast(PreZ2Collect, self.graph[pre_node_idx])
combined_name = f"z2_collect_{order[pre_node_idx]}"
combine_node_idx = self.add_combine_node(
combine_node_idx, actual_register_name = self.add_combine_node(
samplex,
pre_node_idx,
pre_nodes_to_nodes,
Expand All @@ -1586,7 +1612,7 @@ def add_collect_z2_to_output_node(

for reg_name, clbit_idxs in pre_node.clbit_idxs.items():
z2collect = CollectZ2ToOutputNode(
combined_name,
actual_register_name,
np.array(pre_node.subsystems_idxs[reg_name]),
f"measurement_flips.{reg_name}",
clbit_idxs,
Expand Down
5 changes: 1 addition & 4 deletions samplomatic/samplex/nodes/inject_noise_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is a Qiskit project.
#
# (C) Copyright IBM 2025.
# (C) Copyright IBM 2025-2026.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -12,9 +12,6 @@

"""InjectNoiseNode"""

import numpy as np
from qiskit.quantum_info import PauliLindbladMap

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated to this PR, not sure how it got past pre-commit

from ...aliases import NumSubsystems, RegisterName, StrRef
from ...virtual_registers import PauliRegister, VirtualType, Z2Register
from .sampling_node import SamplingNode
Expand Down
32 changes: 31 additions & 1 deletion test/performance/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is a Qiskit project.
#
# (C) Copyright IBM 2025.
# (C) Copyright IBM 2025-2026.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand Down Expand Up @@ -76,3 +76,33 @@ def test_deserialize_noisy_circuit(rng, benchmark, num_qubits, num_gates):
_, samplex = build(circuit)
samplex_json = samplex_to_json(samplex, None)
benchmark(samplex_from_json, samplex_json)


@pytest.mark.parametrize(
("num_qubits", "num_gates"),
[
pytest.param(
100,
5_000,
marks=pytest.mark.skipif(
"config.getoption('--performance-light')", reason="smoke test only"
),
),
pytest.param(
10,
100,
marks=pytest.mark.skipif(
"not config.getoption('--performance-light')", reason="performance test only"
),
),
],
)
def test_serialized_size(rng, benchmark, num_qubits, num_gates):
"""Measure the serialized JSON size of a samplex."""
num_boxes = num_gates // (num_qubits // 2)
circuit = make_layered_circuit(num_qubits, num_boxes, inject_noise=True)

_, samplex = build(circuit)
samplex_json = benchmark(samplex_to_json, samplex)
benchmark.extra_info["serialized_bytes"] = len(samplex_json.encode())
benchmark.extra_info["serialized_kb"] = len(samplex_json.encode()) / 1024