Skip to content

Commit c9975ef

Browse files
committed
SoquetT is a protocol
1 parent 1fdee48 commit c9975ef

File tree

18 files changed

+138
-84
lines changed

18 files changed

+138
-84
lines changed

qualtran/_infra/composite_bloq.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626
Mapping,
2727
Optional,
2828
overload,
29+
Protocol,
2930
Sequence,
3031
Set,
3132
Tuple,
3233
TYPE_CHECKING,
34+
TypeGuard,
3335
TypeVar,
3436
Union,
3537
)
@@ -55,13 +57,15 @@
5557
from qualtran.simulation.classical_sim import ClassicalValT
5658
from qualtran.symbolics import SymbolicInt
5759

58-
# NDArrays must be bound to np.generic
59-
_SoquetType = TypeVar('_SoquetType', bound=np.generic)
6060

61-
SoquetT = Union[Soquet, NDArray[_SoquetType]]
62-
"""A `Soquet` or array of soquets."""
61+
class SoquetT(Protocol):
62+
@property
63+
def shape(self) -> Tuple[int, ...]: ...
6364

64-
SoquetInT = Union[Soquet, NDArray[_SoquetType], Sequence[Soquet]]
65+
def item(self, *args) -> Soquet: ...
66+
67+
68+
SoquetInT = Union[SoquetT, Sequence[SoquetT]]
6569
"""A soquet or array-like of soquets.
6670
6771
This type alias is used for input argument to parts of the library that are more
@@ -693,9 +697,10 @@ def _flatten_soquet_collection(vals: Iterable[SoquetT]) -> List[Soquet]:
693697
"""
694698
soqvals = []
695699
for soq_or_arr in vals:
696-
if isinstance(soq_or_arr, Soquet):
697-
soqvals.append(soq_or_arr)
700+
if BloqBuilder.is_single(soq_or_arr):
701+
soqvals.append(soq_or_arr.item())
698702
else:
703+
assert BloqBuilder.is_ndarray(soq_or_arr)
699704
soqvals.extend(soq_or_arr.reshape(-1))
700705
return soqvals
701706

@@ -802,13 +807,10 @@ def _process_soquets(
802807
unchecked_names.remove(reg.name) # so we can check for surplus arguments.
803808

804809
for li in reg.all_idxs():
805-
idxed_soq = in_soq[li]
806-
assert isinstance(idxed_soq, Soquet), idxed_soq
810+
idxed_soq = in_soq[li].item()
807811
func(idxed_soq, reg, li)
808-
if not check_dtypes_consistent(idxed_soq.reg.dtype, reg.dtype):
809-
extra_str = (
810-
f"{idxed_soq.reg.name}: {idxed_soq.reg.dtype} vs {reg.name}: {reg.dtype}"
811-
)
812+
if not check_dtypes_consistent(idxed_soq.dtype, reg.dtype):
813+
extra_str = f"{idxed_soq.reg.name}: {idxed_soq.dtype} vs {reg.name}: {reg.dtype}"
812814
raise BloqError(
813815
f"{debug_str} register dtypes are not consistent {extra_str}."
814816
) from None
@@ -838,9 +840,9 @@ def _map_soqs(
838840
# First: flatten out any numpy arrays
839841
flat_soq_map: Dict[Soquet, Soquet] = {}
840842
for old_soqs, new_soqs in soq_map:
841-
if isinstance(old_soqs, Soquet):
842-
assert isinstance(new_soqs, Soquet), new_soqs
843-
flat_soq_map[old_soqs] = new_soqs
843+
if BloqBuilder.is_single(old_soqs):
844+
assert BloqBuilder.is_single(new_soqs), new_soqs
845+
flat_soq_map[old_soqs] = new_soqs.item()
844846
continue
845847

846848
assert isinstance(old_soqs, np.ndarray), old_soqs
@@ -858,9 +860,9 @@ def _map_soq(soq: Soquet) -> Soquet:
858860
vmap = np.vectorize(_map_soq, otypes=[object])
859861

860862
def _map_soqs(soqs: SoquetT) -> SoquetT:
861-
if isinstance(soqs, Soquet):
862-
return _map_soq(soqs)
863-
return vmap(soqs)
863+
if BloqBuilder.is_ndarray(soqs):
864+
return vmap(soqs)
865+
return _map_soq(soqs.item())
864866

865867
return {name: _map_soqs(soqs) for name, soqs in soqs.items()}
866868

@@ -1061,6 +1063,24 @@ def from_signature(
10611063

10621064
return bb, initial_soqs
10631065

1066+
@staticmethod
1067+
def is_single(x: 'SoquetT') -> TypeGuard['Soquet']:
1068+
"""Returns True if `x` is a single soquet (not an ndarray of them).
1069+
1070+
This doesn't use stringent runtime type checking; it uses the SoquetT protocol
1071+
for "duck typing".
1072+
"""
1073+
return x.shape == ()
1074+
1075+
@staticmethod
1076+
def is_ndarray(x: 'SoquetT') -> TypeGuard['NDArray']:
1077+
"""Returns True if `x` is an ndarray of soquets (not a single one).
1078+
1079+
This doesn't use stringent runtime type checking; it uses the SoquetT protocol
1080+
for "duck typing".
1081+
"""
1082+
return x.shape != ()
1083+
10641084
@staticmethod
10651085
def map_soqs(
10661086
soqs: Dict[str, SoquetT], soq_map: Iterable[Tuple[SoquetT, SoquetT]]
@@ -1265,8 +1285,7 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]:
12651285
cbloq = bloq.decompose_bloq()
12661286

12671287
for k, v in in_soqs.items():
1268-
if not isinstance(v, Soquet):
1269-
in_soqs[k] = np.asarray(v)
1288+
in_soqs[k] = np.asarray(v)
12701289

12711290
# Initial mapping of LeftDangle according to user-provided in_soqs.
12721291
soq_map: List[Tuple[SoquetT, SoquetT]] = [
@@ -1306,12 +1325,13 @@ def finalize(self, **final_soqs: SoquetT) -> CompositeBloq:
13061325

13071326
def _infer_reg(name: str, soq: SoquetT) -> Register:
13081327
"""Go from Soquet -> register, but use a specific name for the register."""
1309-
if isinstance(soq, Soquet):
1310-
return Register(name=name, dtype=soq.reg.dtype, side=Side.RIGHT)
1328+
if BloqBuilder.is_single(soq):
1329+
return Register(name=name, dtype=soq.dtype, side=Side.RIGHT)
1330+
assert BloqBuilder.is_ndarray(soq)
13111331

13121332
# Get info from 0th soquet in an ndarray.
13131333
return Register(
1314-
name=name, dtype=soq.reshape(-1)[0].reg.dtype, shape=soq.shape, side=Side.RIGHT
1334+
name=name, dtype=soq.reshape(-1).item(0).dtype, shape=soq.shape, side=Side.RIGHT
13151335
)
13161336

13171337
right_reg_names = [reg.name for reg in self._regs if reg.side & Side.RIGHT]
@@ -1358,10 +1378,10 @@ def allocate(
13581378
def free(self, soq: Soquet, dirty: bool = False) -> None:
13591379
from qualtran.bloqs.bookkeeping import Free
13601380

1361-
if not isinstance(soq, Soquet):
1381+
if not BloqBuilder.is_single(soq):
13621382
raise ValueError("`free` expects a single Soquet to free.")
13631383

1364-
qdtype = soq.reg.dtype
1384+
qdtype = soq.dtype
13651385
if not isinstance(qdtype, QDType):
13661386
raise ValueError("`free` can only free quantum registers.")
13671387

@@ -1371,10 +1391,10 @@ def split(self, soq: Soquet) -> NDArray[Soquet]: # type: ignore[type-var]
13711391
"""Add a Split bloq to split up a register."""
13721392
from qualtran.bloqs.bookkeeping import Split
13731393

1374-
if not isinstance(soq, Soquet):
1394+
if not BloqBuilder.is_single(soq):
13751395
raise ValueError("`split` expects a single Soquet to split.")
13761396

1377-
qdtype = soq.reg.dtype
1397+
qdtype = soq.dtype
13781398
if not isinstance(qdtype, QDType):
13791399
raise ValueError("`split` can only split quantum registers.")
13801400

qualtran/_infra/composite_bloq_test.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, List, Tuple
16+
from typing import assert_type, cast, Dict, List, Tuple
1717

1818
import attrs
1919
import networkx as nx
@@ -144,7 +144,7 @@ def test_map_soqs():
144144
assert in_soqs == bb.map_soqs(in_soqs, soq_map)
145145
elif binst.i == 1:
146146
for k, val in bb.map_soqs(in_soqs, soq_map).items():
147-
assert isinstance(val, Soquet)
147+
assert BloqBuilder.is_single(val)
148148
assert isinstance(val.binst, BloqInstance)
149149
assert val.binst.i >= 100
150150
else:
@@ -156,7 +156,7 @@ def test_map_soqs():
156156

157157
fsoqs = bb.map_soqs(cbloq.final_soqs(), soq_map)
158158
for k, val in fsoqs.items():
159-
assert isinstance(val, Soquet)
159+
assert BloqBuilder.is_single(val)
160160
assert isinstance(val.binst, BloqInstance)
161161
assert val.binst.i >= 100
162162
cbloq = bb.finalize(**fsoqs)
@@ -643,6 +643,40 @@ def test_get_soquet():
643643
_ = _get_soquet(binst=binst, reg_name='in', right=True, binst_graph=binst_graph)
644644

645645

646+
def test_can_tell_individual_from_ndsoquet():
647+
s1 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(0,))
648+
s2 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(1,))
649+
s3 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(2,))
650+
s4 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(3,))
651+
652+
# A ndarray of soquet objects should be SoquetT and we can tell by checking its shape.
653+
ndsoq: SoquetT = np.array([s1, s2, s3, s4])
654+
assert_type(ndsoq, SoquetT)
655+
assert ndsoq.shape
656+
assert ndsoq.shape == (4,)
657+
assert ndsoq.item(2) == s3
658+
with pytest.raises(ValueError, match=r'scalar'):
659+
_ = ndsoq.item()
660+
661+
# A single soquet is still a valid SoquetT, and it has a false-y shape.
662+
single_soq: SoquetT = s1
663+
assert_type(single_soq, SoquetT)
664+
assert not single_soq.shape
665+
assert single_soq.shape == ()
666+
single_soq_unwarp = single_soq.item()
667+
assert single_soq_unwarp == s1
668+
669+
# A single soquet wrapped in a 0-dim ndarray is ok if you call `item()`.
670+
single_soq2: SoquetT = np.asarray(s1)
671+
assert_type(single_soq2, SoquetT)
672+
assert not single_soq2.shape
673+
assert single_soq2.shape == ()
674+
single_soq2_unwrap = single_soq2.item()
675+
assert hash(single_soq2_unwrap) == hash(s1)
676+
assert single_soq2_unwrap == s1
677+
assert isinstance(single_soq2_unwrap, Soquet)
678+
679+
646680
@pytest.mark.notebook
647681
def test_notebook():
648682
qlt_testing.execute_notebook('composite_bloq')

qualtran/_infra/quantum_graph.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from attrs import field, frozen
2020

2121
if TYPE_CHECKING:
22-
from qualtran import Bloq, Register
22+
from qualtran import Bloq, BloqBuilder, QCDType, Register
2323

2424

2525
@frozen
@@ -103,6 +103,20 @@ def _check_idx(self, attribute, value):
103103
for i, shape in zip(value, self.reg.shape):
104104
if i >= shape:
105105
raise ValueError(f"Bad index {i} for {self.reg}.")
106+
return value
107+
108+
@property
109+
def dtype(self) -> 'QCDType':
110+
return self.reg.dtype
111+
112+
@property
113+
def shape(self) -> Tuple[int, ...]:
114+
return ()
115+
116+
def item(self, *args) -> 'Soquet':
117+
if args:
118+
raise ValueError("Tried to index into a single soquet.")
119+
return self
106120

107121
def pretty(self) -> str:
108122
label = self.reg.name

qualtran/_infra/quantum_graph_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def test_soquet():
4848
assert soq.idx == ()
4949
assert soq.pretty() == 'x'
5050

51+
assert soq.item() == soq
52+
assert soq.dtype == QAny(10)
53+
5154

5255
def test_soquet_idxed():
5356
binst = BloqInstance(TestTwoBitOp(), i=0)

qualtran/bloqs/arithmetic/subtraction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def wire_symbol(
153153

154154
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
155155
delta = self.b_dtype.bitsize - self.a_dtype.bitsize
156-
costs = {
156+
costs: Dict[Bloq, int] = {
157157
OnEach(self.b_dtype.bitsize, XGate()): 3,
158158
Add(QUInt(self.b_dtype.bitsize), QUInt(self.b_dtype.bitsize)): 1,
159159
}
@@ -196,7 +196,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', a: Soquet, b: Soquet) -> Dict[
196196
a_split[delta], prefix = bb.add(
197197
MultiTargetCNOT(delta), control=a_split[delta], targets=prefix
198198
)
199-
prefix = bb.add(Cast(prefix.reg.dtype, QAny(delta)), reg=prefix)
199+
prefix = bb.add(Cast(prefix.dtype, QAny(delta)), reg=prefix)
200200
bb.free(prefix)
201201
a = bb.join(a_split[delta:], QUInt(self.a_dtype.bitsize))
202202
a = bb.add(Cast(QUInt(self.a_dtype.bitsize), self.a_dtype), reg=a)

qualtran/bloqs/basic_gates/rotation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,7 @@ def signature(self) -> 'Signature':
220220
def build_composite_bloq(self, bb: 'BloqBuilder', q: 'SoquetT') -> Dict[str, 'SoquetT']:
221221
from qualtran.bloqs.mcmt import And
222222

223-
q1, q2 = q # type: ignore
224-
(q1, q2), anc = bb.add(And(), ctrl=[q1, q2])
223+
(q1, q2), anc = bb.add(And(), ctrl=q)
225224
anc = bb.add(ZPowGate(self.exponent, eps=self.eps), q=anc)
226225
(q1, q2) = bb.add(And().adjoint(), ctrl=[q1, q2], target=anc)
227226
return {'q': np.array([q1, q2])}

qualtran/bloqs/block_encoding/sparse_matrix.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,7 @@ def build_composite_bloq(
237237
if is_symbolic(self.system_bitsize) or is_symbolic(self.row_oracle.num_nonzero):
238238
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")
239239

240-
assert not isinstance(ancilla, np.ndarray)
241-
ancilla_bits = bb.split(ancilla)
240+
ancilla_bits = bb.split(ancilla.item())
242241
q, l = ancilla_bits[0], bb.join(ancilla_bits[1:])
243242

244243
l = bb.add(self.diffusion, target=l)

qualtran/bloqs/chemistry/pbc/first_quantization/select_and_prepare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ def _reshape_reg(
153153
"""
154154
# np.prod(()) returns a float (1.0), so take int
155155
size = int(np.prod(out_shape))
156-
if isinstance(in_reg, np.ndarray):
156+
if BloqBuilder.is_ndarray(in_reg):
157157
# split an array of bitsize qubits into flat list of qubits
158158
split_qubits = bb.split(bb.join(np.concatenate([bb.split(x) for x in in_reg.ravel()])))
159159
else:
160-
split_qubits = bb.split(in_reg)
160+
split_qubits = bb.split(in_reg.item())
161161
merged_qubits = np.array(
162162
[bb.join(split_qubits[i * bitsize : (i + 1) * bitsize]) for i in range(size)]
163163
)

qualtran/bloqs/chemistry/trotter/grid_ham/potential.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
QAny,
2828
Register,
2929
Signature,
30-
Soquet,
3130
SoquetT,
3231
)
3332
from qualtran._infra.data_types import BQUInt
@@ -93,7 +92,7 @@ def wire_symbol(
9392
def build_composite_bloq(
9493
self, bb: BloqBuilder, *, system_i: SoquetT, system_j: SoquetT
9594
) -> Dict[str, SoquetT]:
96-
if isinstance(system_i, Soquet) or isinstance(system_j, Soquet):
95+
if not (BloqBuilder.is_ndarray(system_i) and BloqBuilder.is_ndarray(system_j)):
9796
raise ValueError("system_i and system_j must be numpy arrays of Soquet")
9897
# compute r_i - r_j
9998
# r_i + (-r_j), in practice we need to flip the sign bit, but this is just 3 cliffords.
@@ -120,8 +119,8 @@ def build_composite_bloq(
120119
qrom_anc_c2 = bb.allocate(self.poly_bitsize)
121120
qrom_anc_c3 = bb.allocate(self.poly_bitsize)
122121
cast = Cast(
123-
inp_dtype=sos.reg.dtype,
124-
out_dtype=BQUInt(sos.reg.dtype.bitsize, iteration_length=len(self.qrom_data[0])),
122+
inp_dtype=sos.dtype,
123+
out_dtype=BQUInt(sos.dtype.bitsize, iteration_length=len(self.qrom_data[0])),
125124
)
126125
sos = bb.add(cast, reg=sos)
127126
qrom_bloq = QROM(
@@ -227,7 +226,7 @@ def wire_symbol(
227226
return super().wire_symbol(reg, idx)
228227

229228
def build_composite_bloq(self, bb: BloqBuilder, *, system: SoquetT) -> Dict[str, SoquetT]:
230-
if isinstance(system, Soquet):
229+
if not BloqBuilder.is_ndarray(system):
231230
raise ValueError("system must be a numpy array of Soquet")
232231
bitsize = (self.num_grid - 1).bit_length() + 1
233232
ij_pairs = np.triu_indices(self.num_elec, k=1)

qualtran/bloqs/data_loading/qroam_clean.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
520520
# Construct and return dictionary of final soquets.
521521
soqs |= {reg.name: soq for reg, soq in zip(self.control_registers, ctrl)}
522522
soqs |= {reg.name: soq for reg, soq in zip(self.selection_registers, selection)}
523-
soqs |= {reg.name: soq.flat[1:] for reg, soq in zip(self.junk_registers, qrom_targets)} # type: ignore[union-attr]
524-
soqs |= {reg.name: soq.flat[0] for reg, soq in zip(self.target_registers, qrom_targets)} # type: ignore[union-attr]
523+
soqs |= {reg.name: soq.flat[1:] for reg, soq in zip(self.junk_registers, qrom_targets)} # type: ignore[attr-defined]
524+
soqs |= {reg.name: soq.flat[0] for reg, soq in zip(self.target_registers, qrom_targets)} # type: ignore[attr-defined]
525525
return soqs
526526

527527
def on_classical_vals(

0 commit comments

Comments
 (0)