Skip to content

Commit

Permalink
Merge pull request #79 from zapatacomputing/kj/zqs-1279/add-function-…
Browse files Browse the repository at this point in the history
…to-combine-bitstrings

feat: add function to combine bitstrings
  • Loading branch information
Athena Caesura authored Feb 3, 2023
2 parents 3b873d0 + 2194e34 commit 10fae52
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/orquestra/quantum/circuits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
create_layer_of_gates,
)
from ._itertools import (
combine_bitstrings,
combine_measurement_counts,
expand_sample_sizes,
split_into_batches,
Expand Down
51 changes: 48 additions & 3 deletions src/orquestra/quantum/circuits/_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import reduce
from itertools import islice
from math import ceil
from typing import Dict, Iterable, Sequence, Tuple, TypeVar
from typing import Dict, Iterable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")

Expand Down Expand Up @@ -140,7 +140,7 @@ def _combine_measurements(

def combine_measurement_counts(
all_measurements: Sequence[Dict[str, int]], multiplicities: Sequence[int]
) -> Sequence[Dict[str, int]]:
) -> List[Dict[str, int]]:
"""Combine (aggregate) measurements of the same circuits run several times.
Suppose multiplicities is a list [1, 2 ,3]. Then, the all_measurements should
Expand All @@ -165,7 +165,7 @@ def combine_measurement_counts(
Returns:
Sequence of combined measurements of length equal len(multiplicities)
Raises:
ValueError: if len(all_measurements != sum(multiplicities)
ValueError: if len(all_measurements) != sum(multiplicities)
"""
if len(all_measurements) != (sum_multiplicities := sum(multiplicities)):
raise ValueError(
Expand All @@ -178,3 +178,48 @@ def combine_measurement_counts(
reduce(_combine_measurements, islice(measurements_it, multiplicity))
for multiplicity in multiplicities
]


def combine_bitstrings(
all_bitstrings: Sequence[List[str]], multiplicities: Sequence[int]
) -> List[List[str]]:
"""Combine (aggregate) bitstrings of the same circuits run several times.
Suppose multiplicities is a list [1, 2 ,3]. Then, the all_bitstrings should
be a sequence of 1+2+3=6 elements, each of them being a sequence of bitstrings.
For instance, all_bitstrings could be equal to:
[["00", "01"], ["1", "0"], ["0"], ["001", "001"], ["111", "000"], ["000"]]
and then the result would be:
[
["00", "01"],
["1", "0", "0],
["001", "001", "111", "000", "000"]
]
Args:
all_bitstrings: sequence of lists containing bitstrings
gathered from some, possibly duplicated, circuits. The bitstrings
lists objects corresponding to the same circuit should be placed next
to each other. Should have the same length as sum(multiplicities).
multiplicities: sequence of positive integers marking groups of
consecutive measurements corresponding to the same circuit. For
instance, multiplicities [1, 2, 3] mean that first group of
bitstrings comprises 1 sequence, second group comprises 2
consecutive sequences, third group contains 3 consecutive
sequences and so on.
Returns:
Sequence of combined measurements of length equal len(multiplicities)
Raises:
ValueError: if len(all_bitstrings) != sum(multiplicities)
"""
if len(all_bitstrings) != (sum_multiplicities := sum(multiplicities)):
raise ValueError(
"Mismatch between multiplicities and number of measurements to combine. "
f"Got {len(all_bitstrings)} bitstrings lists objects to combine "
f"but multiplicities sum to {sum_multiplicities}"
)
bitstrings_it = iter(all_bitstrings)
return [
sum(islice(bitstrings_it, multiplicity), start=[])
for multiplicity in multiplicities
]
48 changes: 48 additions & 0 deletions tests/orquestra/quantum/circuits/_itertools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from orquestra.quantum.circuits import CNOT, H, X
from orquestra.quantum.circuits._circuit import Circuit
from orquestra.quantum.circuits._itertools import (
combine_bitstrings,
combine_measurement_counts,
expand_sample_sizes,
split_into_batches,
Expand Down Expand Up @@ -139,3 +140,50 @@ def test_counts_of_combined_measurements_are_correct(
combined_counts,
)
)


class TestCombiningBitstrings:
def test_raises_error_when_multiplicities_dont_match_measurements(self):
multiplicities = [1, 2, 3, 2, 2]
# Clearly a mismatch, we should have 10 bitstring sequences
measurements = [["00", "11"] for _ in range(5)]

with pytest.raises(ValueError):
combine_measurement_counts(measurements, multiplicities)

@pytest.mark.parametrize(
"all_bitstrings, multiplicities, combined_bitstrings",
[
([["00", "11"]], [1], [["00", "11"]]),
(
[
["00", "11"],
["01", "00"],
["00", "00"],
["000"],
["111", "001"],
["00000"],
["0", "0", "0"],
["1", "0"],
["0", "0"],
],
[3, 2, 1, 3],
[
["00", "11", "01", "00", "00", "00"],
["000", "111", "001"],
["00000"],
["0", "0", "0", "1", "0", "0", "0"],
],
),
],
)
def test_counts_of_combined_bitstrings_are_correct(
self, all_bitstrings, multiplicities, combined_bitstrings
):
assert all(
actual == expected
for actual, expected in zip(
combine_bitstrings(all_bitstrings, multiplicities),
combined_bitstrings,
)
)

0 comments on commit 10fae52

Please sign in to comment.