diff --git a/src/orquestra/quantum/circuits/__init__.py b/src/orquestra/quantum/circuits/__init__.py index 390f1db..3d6175c 100644 --- a/src/orquestra/quantum/circuits/__init__.py +++ b/src/orquestra/quantum/circuits/__init__.py @@ -188,6 +188,7 @@ create_layer_of_gates, ) from ._itertools import ( + combine_bitstrings, combine_measurement_counts, expand_sample_sizes, split_into_batches, diff --git a/src/orquestra/quantum/circuits/_itertools.py b/src/orquestra/quantum/circuits/_itertools.py index 3b1ef75..33c7a15 100644 --- a/src/orquestra/quantum/circuits/_itertools.py +++ b/src/orquestra/quantum/circuits/_itertools.py @@ -2,7 +2,7 @@ from functools import reduce from itertools import islice from math import ceil -from typing import Dict, Iterable, Sequence, Tuple, TypeVar +from typing import Dict, Iterable, List, Sequence, Tuple, TypeVar T = TypeVar("T") @@ -140,7 +140,7 @@ def _combine_measurements( def combine_measurement_counts( all_measurements: Sequence[Dict[str, int]], multiplicities: Sequence[int] -) -> Sequence[Dict[str, int]]: +) -> List[Dict[str, int]]: """Combine (aggregate) measurements of the same circuits run several times. Suppose multiplicities is a list [1, 2 ,3]. Then, the all_measurements should @@ -165,7 +165,7 @@ def combine_measurement_counts( Returns: Sequence of combined measurements of length equal len(multiplicities) Raises: - ValueError: if len(all_measurements != sum(multiplicities) + ValueError: if len(all_measurements) != sum(multiplicities) """ if len(all_measurements) != (sum_multiplicities := sum(multiplicities)): raise ValueError( @@ -178,3 +178,48 @@ def combine_measurement_counts( reduce(_combine_measurements, islice(measurements_it, multiplicity)) for multiplicity in multiplicities ] + + +def combine_bitstrings( + all_bitstrings: Sequence[List[str]], multiplicities: Sequence[int] +) -> List[List[str]]: + """Combine (aggregate) bitstrings of the same circuits run several times. + + Suppose multiplicities is a list [1, 2 ,3]. Then, the all_bitstrings should + be a sequence of 1+2+3=6 elements, each of them being a sequence of bitstrings. + For instance, all_bitstrings could be equal to: + [["00", "01"], ["1", "0"], ["0"], ["001", "001"], ["111", "000"], ["000"]] + and then the result would be: + [ + ["00", "01"], + ["1", "0", "0], + ["001", "001", "111", "000", "000"] + ] + + Args: + all_bitstrings: sequence of lists containing bitstrings + gathered from some, possibly duplicated, circuits. The bitstrings + lists objects corresponding to the same circuit should be placed next + to each other. Should have the same length as sum(multiplicities). + multiplicities: sequence of positive integers marking groups of + consecutive measurements corresponding to the same circuit. For + instance, multiplicities [1, 2, 3] mean that first group of + bitstrings comprises 1 sequence, second group comprises 2 + consecutive sequences, third group contains 3 consecutive + sequences and so on. + Returns: + Sequence of combined measurements of length equal len(multiplicities) + Raises: + ValueError: if len(all_bitstrings) != sum(multiplicities) + """ + if len(all_bitstrings) != (sum_multiplicities := sum(multiplicities)): + raise ValueError( + "Mismatch between multiplicities and number of measurements to combine. " + f"Got {len(all_bitstrings)} bitstrings lists objects to combine " + f"but multiplicities sum to {sum_multiplicities}" + ) + bitstrings_it = iter(all_bitstrings) + return [ + sum(islice(bitstrings_it, multiplicity), start=[]) + for multiplicity in multiplicities + ] diff --git a/tests/orquestra/quantum/circuits/_itertools_test.py b/tests/orquestra/quantum/circuits/_itertools_test.py index 5d6a299..71ba083 100644 --- a/tests/orquestra/quantum/circuits/_itertools_test.py +++ b/tests/orquestra/quantum/circuits/_itertools_test.py @@ -3,6 +3,7 @@ from orquestra.quantum.circuits import CNOT, H, X from orquestra.quantum.circuits._circuit import Circuit from orquestra.quantum.circuits._itertools import ( + combine_bitstrings, combine_measurement_counts, expand_sample_sizes, split_into_batches, @@ -139,3 +140,50 @@ def test_counts_of_combined_measurements_are_correct( combined_counts, ) ) + + +class TestCombiningBitstrings: + def test_raises_error_when_multiplicities_dont_match_measurements(self): + multiplicities = [1, 2, 3, 2, 2] + # Clearly a mismatch, we should have 10 bitstring sequences + measurements = [["00", "11"] for _ in range(5)] + + with pytest.raises(ValueError): + combine_measurement_counts(measurements, multiplicities) + + @pytest.mark.parametrize( + "all_bitstrings, multiplicities, combined_bitstrings", + [ + ([["00", "11"]], [1], [["00", "11"]]), + ( + [ + ["00", "11"], + ["01", "00"], + ["00", "00"], + ["000"], + ["111", "001"], + ["00000"], + ["0", "0", "0"], + ["1", "0"], + ["0", "0"], + ], + [3, 2, 1, 3], + [ + ["00", "11", "01", "00", "00", "00"], + ["000", "111", "001"], + ["00000"], + ["0", "0", "0", "1", "0", "0", "0"], + ], + ), + ], + ) + def test_counts_of_combined_bitstrings_are_correct( + self, all_bitstrings, multiplicities, combined_bitstrings + ): + assert all( + actual == expected + for actual, expected in zip( + combine_bitstrings(all_bitstrings, multiplicities), + combined_bitstrings, + ) + )