Skip to content

Commit

Permalink
regen docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Sep 24, 2024
1 parent 1cd2d58 commit b2deab3
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
80 changes: 80 additions & 0 deletions doc/chromobius.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,86 @@ class CompiledDecoder:
>>> mistakes = np.count_nonzero(differences)
>>> assert mistakes < shots / 5
"""
@staticmethod
def predict_weighted_obs_flips_from_dets_bit_packed(
dets: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Predicts observable flips and weights from detection events.
The returned weight comes directly from the underlying call to pymatching, not
accounting for the lifting process.
Args:
dets: A bit packed numpy array of detection event data. The array can either
be 1-dimensional (a single shot to decode) or 2-dimensional (multiple
shots to decode, with the first axis being the shot axis and the second
axis being the detection event byte axis).
The array's dtype must be np.uint8. If you have an array of dtype
np.bool_, you have data that's not bit packed. You can pack it by
using `np.packbits(array, bitorder='little')`. But ideally you
should attempt to never have unpacked data in the first place,
since it's 8x larger which can be a large performance loss. For
example, stim's sampler methods all have a `bit_packed=True` argument
that cause them to return bit packed data.
Returns:
A tuple (obs, weights).
Obs is a bit packed numpy array of observable flip data.
Weights is a numpy array (or scalar) of floats.
If dets is a 1D array, then the result has:
obs.shape = (math.ceil(num_obs / 8),)
obs.dtype = np.uint8
weights.shape = ()
weights.dtype = np.float32
If dets is a 2D array, then the result has:
shape = (dets.shape[0], math.ceil(num_obs / 8),)
dtype = np.uint8
weights.shape = (dets.shape[0],)
weights.dtype = np.float32
To determine if the observable with index k was flipped in shot s, compute:
`bool((obs[s, k // 8] >> (k % 8)) & 1)`
Example:
>>> import stim
>>> import chromobius
>>> import numpy as np
>>> repetition_color_code = stim.Circuit('''
... # Apply noise.
... X_ERROR(0.1) 0 1 2 3 4 5 6 7
... # Measure three-body stabilizers to catch errors.
... MPP Z0*Z1*Z2 Z1*Z2*Z3 Z2*Z3*Z4 Z3*Z4*Z5 Z4*Z5*Z6 Z5*Z6*Z7
...
... # Annotate detectors, with a coloring in the 4th coordinate.
... DETECTOR(0, 0, 0, 2) rec[-6]
... DETECTOR(1, 0, 0, 0) rec[-5]
... DETECTOR(2, 0, 0, 1) rec[-4]
... DETECTOR(3, 0, 0, 2) rec[-3]
... DETECTOR(4, 0, 0, 0) rec[-2]
... DETECTOR(5, 0, 0, 1) rec[-1]
...
... # Check on the message.
... M 0
... OBSERVABLE_INCLUDE(0) rec[-1]
... ''')
>>> # Sample the circuit.
>>> shots = 4096
>>> sampler = repetition_color_code.compile_detector_sampler()
>>> dets, actual_obs_flips = sampler.sample(
... shots=shots,
... separate_observables=True,
... bit_packed=True,
... )
>>> # Decode with Chromobius.
>>> dem = repetition_color_code.detector_error_model()
>>> decoder = chromobius.compile_decoder_for_dem(dem)
>>> pred, weights = decoder.predict_obs_flips_from_dets_bit_packed(dets)
"""
def compile_decoder_for_dem(
dem: stim.DetectorErrorModel,
) -> chromobius.CompiledDecoder:
Expand Down
88 changes: 88 additions & 0 deletions doc/chromobius_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- [`chromobius.CompiledDecoder`](#chromobius.CompiledDecoder)
- [`chromobius.CompiledDecoder.from_dem`](#chromobius.CompiledDecoder.from_dem)
- [`chromobius.CompiledDecoder.predict_obs_flips_from_dets_bit_packed`](#chromobius.CompiledDecoder.predict_obs_flips_from_dets_bit_packed)
- [`chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed`](#chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed)
```python
# Types used by the method definitions.
from typing import overload, TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -266,3 +267,90 @@ def predict_obs_flips_from_dets_bit_packed(
>>> assert mistakes < shots / 5
"""
```

<a name="chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed"></a>
```python
# chromobius.CompiledDecoder.predict_weighted_obs_flips_from_dets_bit_packed

# (in class chromobius.CompiledDecoder)
@staticmethod
def predict_weighted_obs_flips_from_dets_bit_packed(
dets: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Predicts observable flips and weights from detection events.
The returned weight comes directly from the underlying call to pymatching, not
accounting for the lifting process.
Args:
dets: A bit packed numpy array of detection event data. The array can either
be 1-dimensional (a single shot to decode) or 2-dimensional (multiple
shots to decode, with the first axis being the shot axis and the second
axis being the detection event byte axis).
The array's dtype must be np.uint8. If you have an array of dtype
np.bool_, you have data that's not bit packed. You can pack it by
using `np.packbits(array, bitorder='little')`. But ideally you
should attempt to never have unpacked data in the first place,
since it's 8x larger which can be a large performance loss. For
example, stim's sampler methods all have a `bit_packed=True` argument
that cause them to return bit packed data.
Returns:
A tuple (obs, weights).
Obs is a bit packed numpy array of observable flip data.
Weights is a numpy array (or scalar) of floats.
If dets is a 1D array, then the result has:
obs.shape = (math.ceil(num_obs / 8),)
obs.dtype = np.uint8
weights.shape = ()
weights.dtype = np.float32
If dets is a 2D array, then the result has:
shape = (dets.shape[0], math.ceil(num_obs / 8),)
dtype = np.uint8
weights.shape = (dets.shape[0],)
weights.dtype = np.float32
To determine if the observable with index k was flipped in shot s, compute:
`bool((obs[s, k // 8] >> (k % 8)) & 1)`
Example:
>>> import stim
>>> import chromobius
>>> import numpy as np
>>> repetition_color_code = stim.Circuit('''
... # Apply noise.
... X_ERROR(0.1) 0 1 2 3 4 5 6 7
... # Measure three-body stabilizers to catch errors.
... MPP Z0*Z1*Z2 Z1*Z2*Z3 Z2*Z3*Z4 Z3*Z4*Z5 Z4*Z5*Z6 Z5*Z6*Z7
...
... # Annotate detectors, with a coloring in the 4th coordinate.
... DETECTOR(0, 0, 0, 2) rec[-6]
... DETECTOR(1, 0, 0, 0) rec[-5]
... DETECTOR(2, 0, 0, 1) rec[-4]
... DETECTOR(3, 0, 0, 2) rec[-3]
... DETECTOR(4, 0, 0, 0) rec[-2]
... DETECTOR(5, 0, 0, 1) rec[-1]
...
... # Check on the message.
... M 0
... OBSERVABLE_INCLUDE(0) rec[-1]
... ''')
>>> # Sample the circuit.
>>> shots = 4096
>>> sampler = repetition_color_code.compile_detector_sampler()
>>> dets, actual_obs_flips = sampler.sample(
... shots=shots,
... separate_observables=True,
... bit_packed=True,
... )
>>> # Decode with Chromobius.
>>> dem = repetition_color_code.detector_error_model()
>>> decoder = chromobius.compile_decoder_for_dem(dem)
>>> pred, weights = decoder.predict_obs_flips_from_dets_bit_packed(dets)
"""
```

0 comments on commit b2deab3

Please sign in to comment.