from __future__ import annotations

+ from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import field
- from typing import (
-     Any,
-     Callable,
-     Dict,
-     Generic,
-     Iterable,
-     Iterator,
-     List,
-     Optional,
-     Sequence,
-     Type,
-     TypeVar,
-     Union,
- )
+ from typing import Any, Generic, Literal, TypeVar, Union

import numpy as np
from pydantic import BaseConfig, Extra, Field, ValidationError, create_model
from pydantic.error_wrappers import ErrorWrapper
- from typing_extensions import Literal

__all__ = [
    "if_instance_do",
@@ -43,11 +30,11 @@ class StrictConfig(BaseConfig):


def discriminated_union_of_subclasses(
-     super_cls: Optional[Type] = None,
+     super_cls: type | None = None,
    *,
    discriminator: str = "type",
-     config: Optional[Type[BaseConfig]] = None,
- ) -> Union[Type, Callable[[Type], Type]]:
+     config: type[BaseConfig] | None = None,
+ ) -> type | Callable[[type], type]:
    """Add all subclasses of super_cls to a discriminated union.

    For all subclasses of super_cls, add a discriminator field to identify
@@ -114,7 +101,7 @@ def calculate(self) -> int:
            subclasses. Defaults to None.

    Returns:
-         Union[Type, Callable[[Type], Type]]: A decorator that adds the necessary
+         Type | Callable[[Type], Type]: A decorator that adds the necessary
            functionality to a class.
    """

@@ -130,12 +117,12 @@ def wrap(cls):


def _discriminated_union_of_subclasses(
-     super_cls: Type,
+     super_cls: type,
    discriminator: str,
-     config: Optional[Type[BaseConfig]] = None,
- ) -> Union[Type, Callable[[Type], Type]]:
-     super_cls._ref_classes = set()
-     super_cls._model = None
+     config: type[BaseConfig] | None = None,
+ ) -> type | Callable[[type], type]:
+     super_cls._ref_classes = set()  # type: ignore
+     super_cls._model = None  # type: ignore

    def __init_subclass__(cls) -> None:
        # Keep track of inheriting classes in super class
@@ -157,7 +144,7 @@ def __validate__(cls, v: Any) -> Any:
        # needs to be done once, after all subclasses have been
        # declared
        if cls._model is None:
-             root = Union[tuple(cls._ref_classes)]  # type: ignore
+             root = Union[tuple(cls._ref_classes)]  # type: ignore # noqa
            cls._model = create_model(
                super_cls.__name__,
                __root__=(root, Field(..., discriminator=discriminator)),
@@ -185,7 +172,7 @@ def __validate__(cls, v: Any) -> Any:
    return super_cls


- def if_instance_do(x: Any, cls: Type, func: Callable):
+ def if_instance_do(x: Any, cls: type, func: Callable):
    """If x is of type cls then return func(x), otherwise return NotImplemented.

    Used as a helper when implementing operator overloading.
@@ -201,7 +188,7 @@ def if_instance_do(x: Any, cls: Type, func: Callable):

#: Map of axes to float ndarray of points
#: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
- AxesPoints = Dict[Axis, np.ndarray]
+ AxesPoints = dict[Axis, np.ndarray]


class Frames(Generic[Axis]):
@@ -234,9 +221,9 @@ class Frames(Generic[Axis]):
    def __init__(
        self,
        midpoints: AxesPoints[Axis],
-         lower: Optional[AxesPoints[Axis]] = None,
-         upper: Optional[AxesPoints[Axis]] = None,
-         gap: Optional[np.ndarray] = None,
+         lower: AxesPoints[Axis] | None = None,
+         upper: AxesPoints[Axis] | None = None,
+         gap: np.ndarray | None = None,
    ):
        #: The midpoints of scan frames for each axis
        self.midpoints = midpoints
@@ -253,7 +240,9 @@ def __init__(
        # We have a gap if upper[i] != lower[i+1] for any axes
        axes_gap = [
            np.roll(upper, 1) != lower
-             for upper, lower in zip(self.upper.values(), self.lower.values())
+             for upper, lower in zip(
+                 self.upper.values(), self.lower.values(), strict=False
+             )
        ]
        self.gap = np.logical_or.reduce(axes_gap)
        # Check all axes and ordering are the same
@@ -270,7 +259,7 @@ def __init__(
        lengths.add(len(self.gap))
        assert len(lengths) <= 1, f"Mismatching lengths {list(lengths)}"

-     def axes(self) -> List[Axis]:
+     def axes(self) -> list[Axis]:
        """The axes which will move during the scan.

        These will be present in `midpoints`, `lower` and `upper`.
@@ -300,7 +289,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]:
                    return {k: v[dim_indices] for k, v in d.items()}
            return {}

-         def extract_gap(gaps: Iterable[np.ndarray]) -> Optional[np.ndarray]:
+         def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None:
            for gap in gaps:
                if not calculate_gap:
                    return gap[dim_indices]
@@ -371,7 +360,7 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
def _merge_frames(
    *stack: Frames[Axis],
    dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]],  # type: ignore
-     gap_merge=Callable[[Sequence[np.ndarray]], Optional[np.ndarray]],
+     gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None],
) -> Frames[Axis]:
    types = {type(fs) for fs in stack}
    assert len(types) == 1, f"Mismatching types for {stack}"
@@ -397,9 +386,9 @@ class SnakedFrames(Frames[Axis]):
    def __init__(
        self,
        midpoints: AxesPoints[Axis],
-         lower: Optional[AxesPoints[Axis]] = None,
-         upper: Optional[AxesPoints[Axis]] = None,
-         gap: Optional[np.ndarray] = None,
+         lower: AxesPoints[Axis] | None = None,
+         upper: AxesPoints[Axis] | None = None,
+         gap: np.ndarray | None = None,
    ):
        super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
        # Override first element of gap to be True, as subsequent runs
@@ -431,7 +420,7 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
        length = len(self)
        backwards = (indices // length) % 2
        snake_indices = np.where(backwards, (length - 1) - indices, indices) % length
-         cls: Type[Frames[Any]]
+         cls: type[Frames[Any]]
        if not calculate_gap:
            cls = Frames
            gap = self.gap[np.where(backwards, length - indices, indices) % length]
@@ -464,7 +453,7 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool:
    return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes())


- def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
+ def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
    """Squash a stack of nested Frames into a single one.

    Args:
@@ -530,7 +519,7 @@ class Path(Generic[Axis]):
    """

    def __init__(
-         self, stack: List[Frames[Axis]], start: int = 0, num: Optional[int] = None
+         self, stack: list[Frames[Axis]], start: int = 0, num: int | None = None
    ):
        #: The Frames stack describing the scan, from slowest to fastest moving
        self.stack = stack
@@ -544,7 +533,7 @@ def __init__(
        if num is not None and start + num < self.end_index:
            self.end_index = start + num

-     def consume(self, num: Optional[int] = None) -> Frames[Axis]:
+     def consume(self, num: int | None = None) -> Frames[Axis]:
        """Consume at most num frames from the Path and return as a Frames object.

        >>> fx = SnakedFrames({"x": np.array([1, 2])})
@@ -613,18 +602,18 @@ class Midpoints(Generic[Axis]):
    >>> fy = Frames({"y": np.array([3, 4])})
    >>> mp = Midpoints([fy, fx])
    >>> for p in mp: print(p)
-     {'y': 3, 'x': 1}
-     {'y': 3, 'x': 2}
-     {'y': 4, 'x': 2}
-     {'y': 4, 'x': 1}
+     {'y': np.int64(3), 'x': np.int64(1)}
+     {'y': np.int64(3), 'x': np.int64(2)}
+     {'y': np.int64(4), 'x': np.int64(2)}
+     {'y': np.int64(4), 'x': np.int64(1)}
    """

-     def __init__(self, stack: List[Frames[Axis]]):
+     def __init__(self, stack: list[Frames[Axis]]):
        #: The stack of Frames describing the scan, from slowest to fastest moving
        self.stack = stack

    @property
-     def axes(self) -> List[Axis]:
+     def axes(self) -> list[Axis]:
        """The axes that will be present in each points dictionary."""
        axes = []
        for frames in self.stack:
@@ -635,7 +624,7 @@ def __len__(self) -> int:
        """The number of dictionaries that will be produced if iterated over."""
        return int(np.prod([len(frames) for frames in self.stack]))

-     def __iter__(self) -> Iterator[Dict[Axis, float]]:
+     def __iter__(self) -> Iterator[dict[Axis, float]]:
        """Yield {axis: midpoint} for each frame in the scan."""
        path = Path(self.stack)
        while len(path):
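Taken together, the hunks above apply one consistent typing modernisation: Optional[X] becomes X | None (PEP 604), List/Dict/Type become the builtin generics list/dict/type (PEP 585), Callable/Iterable/Iterator/Sequence move from typing to collections.abc, and Literal is imported from typing rather than typing_extensions. The zip(..., strict=False) change makes the existing truncating behaviour explicit, and the np.int64(...) doctest outputs appear to track NumPy 2's new scalar repr. Below is a minimal, self-contained sketch of the same before/after pattern applied to toy functions; the function names are illustrative only and are not part of the module being diffed.

# Illustrative sketch of the typing style used after this diff (not library code).
from __future__ import annotations

from collections.abc import Callable, Iterator

import numpy as np


# Before: def first_axis_values(points: Dict[str, np.ndarray], num: Optional[int] = None) -> List[float]
# After: builtin generics and a PEP 604 optional, as in the hunks above.
def first_axis_values(points: dict[str, np.ndarray], num: int | None = None) -> list[float]:
    """Return at most num values from the first axis as plain floats."""
    values = next(iter(points.values()))
    return [float(v) for v in values[:num]]


# Callable and Iterator now come from collections.abc rather than typing.
def apply_to_each(func: Callable[[float], float], values: Iterator[float]) -> list[float]:
    """Apply func to every value drawn from the iterator."""
    return [func(v) for v in values]


if __name__ == "__main__":
    pts = {"x": np.array([1.0, 2.0, 3.0])}
    print(first_axis_values(pts, num=2))                     # [1.0, 2.0]
    print(apply_to_each(lambda v: v * 2, iter([1.0, 2.0])))  # [2.0, 4.0]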