diff --git a/doc/chromobius.pyi b/doc/chromobius.pyi index 12ad3ad..b4411e0 100644 --- a/doc/chromobius.pyi +++ b/doc/chromobius.pyi @@ -162,6 +162,86 @@ class CompiledDecoder: >>> mistakes = np.count_nonzero(differences) >>> assert mistakes < shots / 5 """ + @staticmethod + def predict_weighted_obs_flips_from_dets_bit_packed( + dets: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + """Predicts observable flips and weights from detection events. + + The returned weight comes directly from the underlying call to pymatching, not + accounting for the lifting process. + + Args: + dets: A bit packed numpy array of detection event data. The array can either + be 1-dimensional (a single shot to decode) or 2-dimensional (multiple + shots to decode, with the first axis being the shot axis and the second + axis being the detection event byte axis). + + The array's dtype must be np.uint8. If you have an array of dtype + np.bool_, you have data that's not bit packed. You can pack it by + using `np.packbits(array, bitorder='little')`. But ideally you + should attempt to never have unpacked data in the first place, + since it's 8x larger which can be a large performance loss. For + example, stim's sampler methods all have a `bit_packed=True` argument + that cause them to return bit packed data. + + Returns: + A tuple (obs, weights). + Obs is a bit packed numpy array of observable flip data. + Weights is a numpy array (or scalar) of floats. + + If dets is a 1D array, then the result has: + obs.shape = (math.ceil(num_obs / 8),) + obs.dtype = np.uint8 + weights.shape = () + weights.dtype = np.float32 + If dets is a 2D array, then the result has: + shape = (dets.shape[0], math.ceil(num_obs / 8),) + dtype = np.uint8 + weights.shape = (dets.shape[0],) + weights.dtype = np.float32 + + To determine if the observable with index k was flipped in shot s, compute: + `bool((obs[s, k // 8] >> (k % 8)) & 1)` + + Example: + >>> import stim + >>> import chromobius + >>> import numpy as np + + >>> repetition_color_code = stim.Circuit(''' + ... # Apply noise. + ... X_ERROR(0.1) 0 1 2 3 4 5 6 7 + ... # Measure three-body stabilizers to catch errors. + ... MPP Z0*Z1*Z2 Z1*Z2*Z3 Z2*Z3*Z4 Z3*Z4*Z5 Z4*Z5*Z6 Z5*Z6*Z7 + ... + ... # Annotate detectors, with a coloring in the 4th coordinate. + ... DETECTOR(0, 0, 0, 2) rec[-6] + ... DETECTOR(1, 0, 0, 0) rec[-5] + ... DETECTOR(2, 0, 0, 1) rec[-4] + ... DETECTOR(3, 0, 0, 2) rec[-3] + ... DETECTOR(4, 0, 0, 0) rec[-2] + ... DETECTOR(5, 0, 0, 1) rec[-1] + ... + ... # Check on the message. + ... M 0 + ... OBSERVABLE_INCLUDE(0) rec[-1] + ... ''') + + >>> # Sample the circuit. + >>> shots = 4096 + >>> sampler = repetition_color_code.compile_detector_sampler() + >>> dets, actual_obs_flips = sampler.sample( + ... shots=shots, + ... separate_observables=True, + ... bit_packed=True, + ... ) + + >>> # Decode with Chromobius. + >>> dem = repetition_color_code.detector_error_model() + >>> decoder = chromobius.compile_decoder_for_dem(dem) + >>> pred, weights = decoder.predict_obs_flips_from_dets_bit_packed(dets) + """ def compile_decoder_for_dem( dem: stim.DetectorErrorModel, ) -> chromobius.CompiledDecoder: diff --git a/doc/chromobius_api_reference.md b/doc/chromobius_api_reference.md index a6789b1..8018f42 100644 --- a/doc/chromobius_api_reference.md +++ b/doc/chromobius_api_reference.md @@ -7,6 +7,7 @@ - [`chromobius.CompiledDecoder`](#chromobius.CompiledDecoder) - [`chromobius.CompiledDecoder.from_dem`](#chromobius.CompiledDecoder.from_dem) - [`chromobius.CompiledDecoder.predict_obs_flips_from_dets_bit_packed`](#chromobius.CompiledDecoder.predict_obs_flips_from_dets_bit_packed) + - [`chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed`](#chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed) ```python # Types used by the method definitions. from typing import overload, TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union @@ -266,3 +267,90 @@ def predict_obs_flips_from_dets_bit_packed( >>> assert mistakes < shots / 5 """ ``` + + +```python +# chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed + +# (in class chromobius.CompiledDecoder) +@staticmethod +def predict_weighted_obs_flips_from_dets_bit_packed( + dets: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Predicts observable flips and weights from detection events. + + The returned weight comes directly from the underlying call to pymatching, not + accounting for the lifting process. + + Args: + dets: A bit packed numpy array of detection event data. The array can either + be 1-dimensional (a single shot to decode) or 2-dimensional (multiple + shots to decode, with the first axis being the shot axis and the second + axis being the detection event byte axis). + + The array's dtype must be np.uint8. If you have an array of dtype + np.bool_, you have data that's not bit packed. You can pack it by + using `np.packbits(array, bitorder='little')`. But ideally you + should attempt to never have unpacked data in the first place, + since it's 8x larger which can be a large performance loss. For + example, stim's sampler methods all have a `bit_packed=True` argument + that cause them to return bit packed data. + + Returns: + A tuple (obs, weights). + Obs is a bit packed numpy array of observable flip data. + Weights is a numpy array (or scalar) of floats. + + If dets is a 1D array, then the result has: + obs.shape = (math.ceil(num_obs / 8),) + obs.dtype = np.uint8 + weights.shape = () + weights.dtype = np.float32 + If dets is a 2D array, then the result has: + shape = (dets.shape[0], math.ceil(num_obs / 8),) + dtype = np.uint8 + weights.shape = (dets.shape[0],) + weights.dtype = np.float32 + + To determine if the observable with index k was flipped in shot s, compute: + `bool((obs[s, k // 8] >> (k % 8)) & 1)` + + Example: + >>> import stim + >>> import chromobius + >>> import numpy as np + + >>> repetition_color_code = stim.Circuit(''' + ... # Apply noise. + ... X_ERROR(0.1) 0 1 2 3 4 5 6 7 + ... # Measure three-body stabilizers to catch errors. + ... MPP Z0*Z1*Z2 Z1*Z2*Z3 Z2*Z3*Z4 Z3*Z4*Z5 Z4*Z5*Z6 Z5*Z6*Z7 + ... + ... # Annotate detectors, with a coloring in the 4th coordinate. + ... DETECTOR(0, 0, 0, 2) rec[-6] + ... DETECTOR(1, 0, 0, 0) rec[-5] + ... DETECTOR(2, 0, 0, 1) rec[-4] + ... DETECTOR(3, 0, 0, 2) rec[-3] + ... DETECTOR(4, 0, 0, 0) rec[-2] + ... DETECTOR(5, 0, 0, 1) rec[-1] + ... + ... # Check on the message. + ... M 0 + ... OBSERVABLE_INCLUDE(0) rec[-1] + ... ''') + + >>> # Sample the circuit. + >>> shots = 4096 + >>> sampler = repetition_color_code.compile_detector_sampler() + >>> dets, actual_obs_flips = sampler.sample( + ... shots=shots, + ... separate_observables=True, + ... bit_packed=True, + ... ) + + >>> # Decode with Chromobius. + >>> dem = repetition_color_code.detector_error_model() + >>> decoder = chromobius.compile_decoder_for_dem(dem) + >>> pred, weights = decoder.predict_obs_flips_from_dets_bit_packed(dets) + """ +```