Skip to content

Commit

Permalink
Updates after main merge
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Aug 7, 2024
1 parent 2e188d3 commit b0df220
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 63 deletions.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ description = "Specify step and flyscan paths in a serializable, efficient and P
dependencies = [
"numpy>=2",
"click>=8.1",
"pydantic>2.0",
"httpx==0.26.0",
"typing_extensions",
"pydantic>=2.0",
]
dynamic = ["version"]
license.file = "LICENSE"
Expand Down
32 changes: 5 additions & 27 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial
from inspect import isclass
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
Type,
Literal,
TypeVar,
Union,
get_origin,
get_type_hints,
)
Expand All @@ -28,16 +22,6 @@
)
from pydantic.dataclasses import rebuild_dataclass
from pydantic.fields import FieldInfo
from typing_extensions import Literal

from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import field
from typing import Any, Generic, Literal, TypeVar, Union

import numpy as np
from pydantic import BaseConfig, Extra, Field, ValidationError, create_model
from pydantic.error_wrappers import ErrorWrapper


__all__ = [
"if_instance_do",
Expand All @@ -54,14 +38,13 @@
]


StrictConfig: ConfigDict = {"extra": "forbid"}
StrictConfig: ConfigDict = ConfigDict(extra="forbid")


def discriminated_union_of_subclasses(
cls,
discriminator: str = "type",
):

"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
Expand Down Expand Up @@ -142,13 +125,10 @@ def calculate(self) -> int:
T = TypeVar("T", type, Callable)




def deserialize_as(cls, obj):
return TypeAdapter(_tagged_unions[cls]._make_union()).validate_python(obj)



def uses_tagged_union(cls_or_func: T) -> T:
for k, v in get_type_hints(cls_or_func).items():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
Expand All @@ -157,8 +137,6 @@ def uses_tagged_union(cls_or_func: T) -> T:
return cls_or_func




class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
Expand All @@ -174,7 +152,7 @@ def _make_union(self):
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
if len(self._members) > 1:
# Unions are only valid with more than 1 member
return Union[tuple(self._members)] # type: ignore
return tuple(self._members) # type: ignore

def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
# Set the field to use the `type` discriminator on deserialize
Expand Down
2 changes: 0 additions & 2 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int
yield unscaled_splines



@uses_tagged_union
def plot_spec(spec: Spec[Any], title: str | None = None):

"""Plot a spec, drawing the path taken through the scan.
Uses a different colour for each frame, grey for the turnarounds, and
Expand Down
21 changes: 9 additions & 12 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from __future__ import annotations


from dataclasses import is_dataclass
from typing import Generic, Iterator, List, Set
from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import Generic


import numpy as np
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
Expand Down Expand Up @@ -112,7 +109,7 @@ class Range(Region[Axis]):
min: float = Field(description="The minimum inclusive value in the region")
max: float = Field(description="The minimum inclusive value in the region")

def axis_sets(self) -> List[Set[Axis]]:
def axis_sets(self) -> list[set[Axis]]:
return [{self.axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
Expand Down Expand Up @@ -144,7 +141,7 @@ class Rectangle(Region[Axis]):
description="Clockwise rotation angle of the rectangle", default=0.0
)

def axis_sets(self) -> List[Set[Axis]]:
def axis_sets(self) -> list[set[Axis]]:
return [{self.x_axis, self.y_axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
Expand Down Expand Up @@ -177,22 +174,22 @@ class Polygon(Region[Axis]):

x_axis: Axis = Field(description="The name matching the x axis of the spec")
y_axis: Axis = Field(description="The name matching the y axis of the spec")
x_verts: List[float] = Field(
x_verts: list[float] = Field(
description="The Nx1 x coordinates of the polygons vertices", min_length=3
)
y_verts: List[float] = Field(
y_verts: list[float] = Field(
description="The Nx1 y coordinates of the polygons vertices", min_length=3
)

def axis_sets(self) -> List[Set[Axis]]:
def axis_sets(self) -> list[set[Axis]]:
return [{self.x_axis, self.y_axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
x = points[self.x_axis]
y = points[self.y_axis]
v1x, v1y = self.x_verts[-1], self.y_verts[-1]
mask = np.full(len(x), False, dtype=np.int8)
for v2x, v2y in zip(self.x_verts, self.y_verts):
for v2x, v2y in zip(self.x_verts, self.y_verts, strict=False):
# skip horizontal edges
if v2y != v1y:
vmask = np.full(len(x), False, dtype=np.int8)
Expand Down Expand Up @@ -224,7 +221,7 @@ class Circle(Region[Axis]):
y_middle: float = Field(description="The central y point of the circle")
radius: float = Field(description="Radius of the circle", gt=0)

def axis_sets(self) -> List[Set[Axis]]:
def axis_sets(self) -> list[set[Axis]]:
return [{self.x_axis, self.y_axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
Expand Down Expand Up @@ -259,7 +256,7 @@ class Ellipse(Region[Axis]):
)
angle: float = Field(description="The angle of the ellipse (degrees)", default=0.0)

def axis_sets(self) -> List[Set[Axis]]:
def axis_sets(self) -> list[set[Axis]]:
return [{self.x_axis, self.y_axis}]

def mask(self, points: AxesPoints[Axis]) -> np.ndarray:
Expand Down
31 changes: 12 additions & 19 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@
from dataclasses import asdict
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Tuple,
Type,
)

import numpy as np
Expand Down Expand Up @@ -153,14 +146,14 @@ class Concat(Spec[Axis]):
default=True,
)

def axes(self) -> List:
def axes(self) -> list:
left_axes, right_axes = self.left.axes(), self.right.axes()
# Assuming the axes are the same, the order does not matter, we inherit the
# order from the left-hand side. See also scanspec.core.concat.
assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}"
return left_axes

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]:
dim_left = squash_frames(
self.left.calculate(bounds, nested), nested and self.check_path_changes
)
Expand All @@ -187,10 +180,10 @@ class Line(Spec[Axis]):
stop: float = Field(description="Midpoint of the last point of the line")
num: int = Field(ge=1, description="Number of frames to produce")

def axes(self) -> List:
def axes(self) -> list:
return [self.axis]

def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]:
if self.num == 1:
# Only one point, stop-start gives length of one point
step = self.stop - self.start
Expand All @@ -202,7 +195,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
first = self.start - step / 2
return {self.axis: indexes * step + first}

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]:
return _dimensions_from_indexes(
self._line_from_indexes, self.axes(), self.num, bounds
)
Expand Down Expand Up @@ -256,7 +249,7 @@ class Static(Spec[Axis]):

@classmethod
def duration(
cls: Type[Static],
cls: type[Static],
duration: float = Field(description="The duration of each static point"),
num: int = Field(ge=1, description="Number of frames to produce", default=1),
) -> Static[str]:
Expand All @@ -270,13 +263,13 @@ def duration(
"""
return cls(DURATION, duration, num)

def axes(self) -> List:
def axes(self) -> list:
return [self.axis]

def _repeats_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
def _repeats_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]:
return {self.axis: np.full(len(indexes), self.value)}

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]:
return _dimensions_from_indexes(
self._repeats_from_indexes, self.axes(), self.num, bounds
)
Expand Down Expand Up @@ -312,11 +305,11 @@ class Spiral(Spec[Axis]):
description="How much to rotate the angle of the spiral", default=0.0
)

def axes(self) -> List[Axis]:
def axes(self) -> list[Axis]:
# TODO: reversed from __init__ args, a good idea?
return [self.y_axis, self.x_axis]

def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
def _spiral_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]:
# simplest spiral equation: r = phi
# we want point spacing across area to be the same as between rings
# so: sqrt(area / num) = ring_spacing
Expand All @@ -333,7 +326,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
self.x_axis: self.x_start + x_scale * phi * np.sin(phi + self.rotate),
}

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]:
return _dimensions_from_indexes(
self._spiral_from_indexes, self.axes(), self.num, bounds
)
Expand Down

0 comments on commit b0df220

Please sign in to comment.