2626 Mapping ,
2727 Optional ,
2828 overload ,
29+ Protocol ,
2930 Sequence ,
3031 Set ,
3132 Tuple ,
3233 TYPE_CHECKING ,
34+ TypeGuard ,
3335 TypeVar ,
3436 Union ,
3537)
5557 from qualtran .simulation .classical_sim import ClassicalValT
5658 from qualtran .symbolics import SymbolicInt
5759
58- # NDArrays must be bound to np.generic
59- _SoquetType = TypeVar ('_SoquetType' , bound = np .generic )
6060
61- SoquetT = Union [Soquet , NDArray [_SoquetType ]]
62- """A `Soquet` or array of soquets."""
61+ class SoquetT (Protocol ):
62+ @property
63+ def shape (self ) -> Tuple [int , ...]: ...
6364
64- SoquetInT = Union [Soquet , NDArray [_SoquetType ], Sequence [Soquet ]]
65+ def item (self , * args ) -> Soquet : ...
66+
67+
68+ SoquetInT = Union [SoquetT , Sequence [SoquetT ]]
6569"""A soquet or array-like of soquets.
6670
6771This type alias is used for input argument to parts of the library that are more
@@ -693,9 +697,10 @@ def _flatten_soquet_collection(vals: Iterable[SoquetT]) -> List[Soquet]:
693697 """
694698 soqvals = []
695699 for soq_or_arr in vals :
696- if isinstance (soq_or_arr , Soquet ):
697- soqvals .append (soq_or_arr )
700+ if BloqBuilder . is_single (soq_or_arr ):
701+ soqvals .append (soq_or_arr . item () )
698702 else :
703+ assert BloqBuilder .is_ndarray (soq_or_arr )
699704 soqvals .extend (soq_or_arr .reshape (- 1 ))
700705 return soqvals
701706
@@ -802,13 +807,10 @@ def _process_soquets(
802807 unchecked_names .remove (reg .name ) # so we can check for surplus arguments.
803808
804809 for li in reg .all_idxs ():
805- idxed_soq = in_soq [li ]
806- assert isinstance (idxed_soq , Soquet ), idxed_soq
810+ idxed_soq = in_soq [li ].item ()
807811 func (idxed_soq , reg , li )
808- if not check_dtypes_consistent (idxed_soq .reg .dtype , reg .dtype ):
809- extra_str = (
810- f"{ idxed_soq .reg .name } : { idxed_soq .reg .dtype } vs { reg .name } : { reg .dtype } "
811- )
812+ if not check_dtypes_consistent (idxed_soq .dtype , reg .dtype ):
813+ extra_str = f"{ idxed_soq .reg .name } : { idxed_soq .dtype } vs { reg .name } : { reg .dtype } "
812814 raise BloqError (
813815 f"{ debug_str } register dtypes are not consistent { extra_str } ."
814816 ) from None
@@ -838,9 +840,9 @@ def _map_soqs(
838840 # First: flatten out any numpy arrays
839841 flat_soq_map : Dict [Soquet , Soquet ] = {}
840842 for old_soqs , new_soqs in soq_map :
841- if isinstance (old_soqs , Soquet ):
842- assert isinstance (new_soqs , Soquet ), new_soqs
843- flat_soq_map [old_soqs ] = new_soqs
843+ if BloqBuilder . is_single (old_soqs ):
844+ assert BloqBuilder . is_single (new_soqs ), new_soqs
845+ flat_soq_map [old_soqs ] = new_soqs . item ()
844846 continue
845847
846848 assert isinstance (old_soqs , np .ndarray ), old_soqs
@@ -858,9 +860,9 @@ def _map_soq(soq: Soquet) -> Soquet:
858860 vmap = np .vectorize (_map_soq , otypes = [object ])
859861
860862 def _map_soqs (soqs : SoquetT ) -> SoquetT :
861- if isinstance (soqs , Soquet ):
862- return _map_soq (soqs )
863- return vmap (soqs )
863+ if BloqBuilder . is_ndarray (soqs ):
864+ return vmap (soqs )
865+ return _map_soq (soqs . item () )
864866
865867 return {name : _map_soqs (soqs ) for name , soqs in soqs .items ()}
866868
@@ -1061,6 +1063,24 @@ def from_signature(
10611063
10621064 return bb , initial_soqs
10631065
1066+ @staticmethod
1067+ def is_single (x : 'SoquetT' ) -> TypeGuard ['Soquet' ]:
1068+ """Returns True if `x` is a single soquet (not an ndarray of them).
1069+
1070+ This doesn't use stringent runtime type checking; it uses the SoquetT protocol
1071+ for "duck typing".
1072+ """
1073+ return x .shape == ()
1074+
1075+ @staticmethod
1076+ def is_ndarray (x : 'SoquetT' ) -> TypeGuard ['NDArray' ]:
1077+ """Returns True if `x` is an ndarray of soquets (not a single one).
1078+
1079+ This doesn't use stringent runtime type checking; it uses the SoquetT protocol
1080+ for "duck typing".
1081+ """
1082+ return x .shape != ()
1083+
10641084 @staticmethod
10651085 def map_soqs (
10661086 soqs : Dict [str , SoquetT ], soq_map : Iterable [Tuple [SoquetT , SoquetT ]]
@@ -1265,8 +1285,7 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]:
12651285 cbloq = bloq .decompose_bloq ()
12661286
12671287 for k , v in in_soqs .items ():
1268- if not isinstance (v , Soquet ):
1269- in_soqs [k ] = np .asarray (v )
1288+ in_soqs [k ] = np .asarray (v )
12701289
12711290 # Initial mapping of LeftDangle according to user-provided in_soqs.
12721291 soq_map : List [Tuple [SoquetT , SoquetT ]] = [
@@ -1306,12 +1325,13 @@ def finalize(self, **final_soqs: SoquetT) -> CompositeBloq:
13061325
13071326 def _infer_reg (name : str , soq : SoquetT ) -> Register :
13081327 """Go from Soquet -> register, but use a specific name for the register."""
1309- if isinstance (soq , Soquet ):
1310- return Register (name = name , dtype = soq .reg .dtype , side = Side .RIGHT )
1328+ if BloqBuilder .is_single (soq ):
1329+ return Register (name = name , dtype = soq .dtype , side = Side .RIGHT )
1330+ assert BloqBuilder .is_ndarray (soq )
13111331
13121332 # Get info from 0th soquet in an ndarray.
13131333 return Register (
1314- name = name , dtype = soq .reshape (- 1 )[ 0 ]. reg .dtype , shape = soq .shape , side = Side .RIGHT
1334+ name = name , dtype = soq .reshape (- 1 ). item ( 0 ) .dtype , shape = soq .shape , side = Side .RIGHT
13151335 )
13161336
13171337 right_reg_names = [reg .name for reg in self ._regs if reg .side & Side .RIGHT ]
@@ -1358,10 +1378,10 @@ def allocate(
13581378 def free (self , soq : Soquet , dirty : bool = False ) -> None :
13591379 from qualtran .bloqs .bookkeeping import Free
13601380
1361- if not isinstance (soq , Soquet ):
1381+ if not BloqBuilder . is_single (soq ):
13621382 raise ValueError ("`free` expects a single Soquet to free." )
13631383
1364- qdtype = soq .reg . dtype
1384+ qdtype = soq .dtype
13651385 if not isinstance (qdtype , QDType ):
13661386 raise ValueError ("`free` can only free quantum registers." )
13671387
@@ -1371,10 +1391,10 @@ def split(self, soq: Soquet) -> NDArray[Soquet]: # type: ignore[type-var]
13711391 """Add a Split bloq to split up a register."""
13721392 from qualtran .bloqs .bookkeeping import Split
13731393
1374- if not isinstance (soq , Soquet ):
1394+ if not BloqBuilder . is_single (soq ):
13751395 raise ValueError ("`split` expects a single Soquet to split." )
13761396
1377- qdtype = soq .reg . dtype
1397+ qdtype = soq .dtype
13781398 if not isinstance (qdtype , QDType ):
13791399 raise ValueError ("`split` can only split quantum registers." )
13801400
0 commit comments