Skip to content

Commit

Permalink
Bugfix in Partition to support QGF type registers (#1448)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored Oct 8, 2024
1 parent 9d3f6b2 commit fbb031b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion qualtran/bloqs/bookkeeping/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']
def _classical_unpartition_to_bits(self, **vals: 'ClassicalValT') -> NDArray[np.uint8]:
out_vals: list[NDArray[np.uint8]] = []
for reg in self.regs:
reg_val = np.asarray(vals[reg.name])
reg_val = np.asanyarray(vals[reg.name])
bitstrings = reg.dtype.to_bits_array(reg_val.ravel())
out_vals.append(bitstrings.ravel())
return np.concatenate(out_vals)
Expand Down
13 changes: 12 additions & 1 deletion qualtran/bloqs/bookkeeping/partition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from attrs import frozen

from qualtran import Bloq, BloqBuilder, QAny, Register, Signature, Soquet, SoquetT
from qualtran import Bloq, BloqBuilder, QAny, QGF, Register, Signature, Soquet, SoquetT
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import CNOT
from qualtran.bloqs.bookkeeping import Partition
Expand Down Expand Up @@ -117,3 +117,14 @@ def test_partition_call_classically():
assert flat_out[2] == 2
out = bloq.adjoint().call_classically(**{reg.name: val for (reg, val) in zip(regs, out)})
assert out[0] == 64


def test_partition_call_classically_gf():
dtypes = [QGF(2, 2), QGF(2, 3)]
regs = (Register('xx', dtypes[0]), Register('yy', dtypes[1]))
partition = Partition(n=5, regs=regs)
unpartition = partition.adjoint()
for x in range(2**5):
xx, yy = partition.call_classically(x=x)
assert isinstance(xx, dtypes[0].gf_type) and isinstance(yy, dtypes[1].gf_type)
assert (x,) == unpartition.call_classically(xx=xx, yy=yy)

0 comments on commit fbb031b

Please sign in to comment.