Skip to content

Commit

Permalink
Put pyright in strict mode (#143)
Browse files Browse the repository at this point in the history
Enable strict mode for pyright and change code to fix type errors while preserving tests passing etc.
  • Loading branch information
callumforrester authored Sep 16, 2024
1 parent e42f8eb commit e1220f5
Show file tree
Hide file tree
Showing 17 changed files with 509 additions and 2,215 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
# ('envvar', 'LD_LIBRARY_PATH').
nitpick_ignore = [
("py:class", "scanspec.core.C"),
("py:class", "scanspec.core.T"),
("py:class", "pydantic.config.ConfigDict"),
]

Expand Down
10 changes: 3 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ name = "Tom Cobb"
version_file = "src/scanspec/_version.py"

[tool.pyright]
typeCheckingMode = "standard"
reportMissingImports = false # Ignore missing stubs in imported modules
typeCheckingMode = "strict"
reportMissingImports = false # Ignore missing stubs in imported modules

[tool.pytest.ini_options]
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
Expand Down Expand Up @@ -135,9 +135,5 @@ convention = "google"
[tool.ruff.lint.per-file-ignores]

"tests/**/*" = [
# By default, private member access is allowed in tests
# See https://github.com/DiamondLightSource/python-copier-template/issues/154
# Remove this line to forbid private member access in tests
"SLF001",
"D", # Don't check docstrings in tests
"D", # Don't check docstrings in tests
]
1,863 changes: 1 addition & 1,862 deletions schema.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/scanspec/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
@click.version_option(prog_name="scanspec", message="%(version)s")
@click.pass_context
def cli(ctx, log_level: str):
def cli(ctx: click.Context, log_level: str):
"""Top level scanspec command line interface."""
level = getattr(logging, log_level.upper(), None)
logging.basicConfig(format="%(levelname)s:%(message)s", level=level)
Expand Down Expand Up @@ -50,7 +50,7 @@ def plot(spec: str):
@click.option(
"--port", default=8080, help="The port that the scanspec service will be hosted on."
)
def service(cors, port):
def service(cors: bool, port: int):
"""Run up a REST service."""
from scanspec.service import run_app

Expand Down
109 changes: 66 additions & 43 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@

from __future__ import annotations

import itertools
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
from typing import Any, Generic, Literal, TypeVar
from typing import (
Any,
Generic,
Literal,
TypeVar,
)

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
"Axis",
"OtherAxis",
"if_instance_do",
"AxesPoints",
"Frames",
"SnakedFrames",
Expand All @@ -31,6 +39,9 @@
StrictConfig: ConfigDict = {"extra": "forbid"}

C = TypeVar("C")
T = TypeVar("T")

GapArray = npt.NDArray[np.bool]


def discriminated_union_of_subclasses(
Expand Down Expand Up @@ -111,7 +122,7 @@ def calculate(self) -> int:
tagged_union = _TaggedUnion(super_cls, discriminator)
_tagged_unions[super_cls] = tagged_union

def add_subclass_to_union(subclass):
def add_subclass_to_union(subclass: type[C]):
# Add a discriminator field to a subclass so it can
# be identified when deserializing
subclass.__annotations__ = {
Expand All @@ -120,7 +131,9 @@ def add_subclass_to_union(subclass):
}
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
def get_schema_of_union(
cls: type[C], source_type: Any, handler: GetCoreSchemaHandler
):
if cls is not super_cls:
tagged_union.add_member(cls)
return handler(cls)
Expand All @@ -138,7 +151,7 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):


class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
def __init__(self, base_class: type[Any], discriminator: str):
self._base_class = base_class
# Classes and their field names that refer to this tagged union
self._discriminator = discriminator
Expand All @@ -154,7 +167,7 @@ def add_member(self, cls: type):
_TaggedUnion._rebuild(member)

@staticmethod
def _rebuild(cls_or_func: type | Callable):
def _rebuild(cls_or_func: Callable[..., T]) -> None:
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
rebuild_dataclass(cls_or_func, force=True)
Expand All @@ -170,11 +183,13 @@ def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:


@lru_cache(1)
def _make_schema(members: tuple[type, ...], handler):
def _make_schema(
members: tuple[type[Any], ...], handler: Callable[[Any], CoreSchema]
) -> dict[str, CoreSchema]:
return {member.__name__: handler(member) for member in members}


def if_instance_do(x: Any, cls: type, func: Callable):
def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]) -> T:
"""If x is of type cls then return func(x), otherwise return NotImplemented.
Used as a helper when implementing operator overloading.
Expand All @@ -186,11 +201,14 @@ def if_instance_do(x: Any, cls: type, func: Callable):


#: A type variable for an `axis_` that can be specified for a scan
Axis = TypeVar("Axis")
Axis = TypeVar("Axis", covariant=True)

#: Alternative axis variable to be used when two are required in the same type binding
OtherAxis = TypeVar("OtherAxis")

#: Map of axes to float ndarray of points
#: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
AxesPoints = dict[Axis, np.ndarray]
AxesPoints = dict[Axis, npt.NDArray[np.floating[Any]]]


class Frames(Generic[Axis]):
Expand Down Expand Up @@ -226,7 +244,7 @@ def __init__(
midpoints: AxesPoints[Axis],
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
gap: np.ndarray | None = None,
gap: GapArray | None = None,
):
#: The midpoints of scan frames for each axis
self.midpoints = midpoints
Expand Down Expand Up @@ -274,7 +292,9 @@ def __len__(self) -> int:
# All axespoints arrays are same length, pick the first one
return len(self.gap)

def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
def extract(
self, indices: npt.NDArray[np.signedinteger[Any]], calculate_gap: bool = True
) -> Frames[Axis]:
"""Return a new Frames object restricted to the indices provided.
Args:
Expand All @@ -293,7 +313,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]:
return {k: v[dim_indices] for k, v in d.items()}
return {}

def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None:
def extract_gap(gaps: Iterable[GapArray]) -> GapArray | None:
for gap in gaps:
if not calculate_gap:
return gap[dim_indices]
Expand Down Expand Up @@ -326,7 +346,7 @@ def concat_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# lower[ax] = np.concatenate(self.lower[ax], other.lower[ax])
return {a: np.concatenate([d[a] for d in ds]) for a in self.axes()}

def concat_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
def concat_gap(gaps: Sequence[GapArray]) -> GapArray:
g = np.concatenate(gaps)
# Calc the first frame
g[0] = gap_between_frames(other, self)
Expand Down Expand Up @@ -354,7 +374,7 @@ def zip_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# lower[ax] = {**self.lower[ax], **other.lower[ax]}
return dict(kv for d in ds for kv in d.items())

def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
def zip_gap(gaps: Sequence[GapArray]) -> GapArray:
# Gap if either frames has a gap. E.g.
# gap[i] = self.gap[i] | other.gap[i]
return np.logical_or.reduce(gaps)
Expand All @@ -364,24 +384,24 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:

def _merge_frames(
*stack: Frames[Axis],
dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]], # type: ignore
gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None],
dict_merge: Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]], # type: ignore
gap_merge: Callable[[Sequence[GapArray]], GapArray | None],
) -> Frames[Axis]:
types = {type(fs) for fs in stack}
assert len(types) == 1, f"Mismatching types for {stack}"
cls = types.pop()

# If any lower or upper are different, apply to those
kwargs = {}
for a in ("lower", "upper"):
if any(fs.midpoints is not getattr(fs, a) for fs in stack):
kwargs[a] = dict_merge([getattr(fs, a) for fs in stack])

# Apply to midpoints, force calculation of gap
return cls(
midpoints=dict_merge([fs.midpoints for fs in stack]),
gap=gap_merge([fs.gap for fs in stack]),
**kwargs,
# If any lower or upper are different, apply to those
lower=dict_merge([fs.lower for fs in stack])
if any(fs.midpoints is not fs.lower for fs in stack)
else None,
upper=dict_merge([fs.upper for fs in stack])
if any(fs.midpoints is not fs.upper for fs in stack)
else None,
)


Expand All @@ -393,19 +413,23 @@ def __init__(
midpoints: AxesPoints[Axis],
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
gap: np.ndarray | None = None,
gap: GapArray | None = None,
):
super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
# Override first element of gap to be True, as subsequent runs
# of snake scans are always joined end -> start
self.gap[0] = False

@classmethod
def from_frames(cls, frames: Frames[Axis]) -> SnakedFrames[Axis]:
def from_frames(
cls: type[SnakedFrames[Any]], frames: Frames[OtherAxis]
) -> SnakedFrames[OtherAxis]:
"""Create a snaked version of a `Frames` object."""
return cls(frames.midpoints, frames.lower, frames.upper, frames.gap)

def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
def extract(
self, indices: npt.NDArray[np.signedinteger[Any]], calculate_gap: bool = True
) -> Frames[Axis]:
"""Return a new Frames object restricted to the indices provided.
Args:
Expand Down Expand Up @@ -434,23 +458,23 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
cls = type(self)
gap = None

# If lower or upper are different, apply to those
kwargs = {}
if self.midpoints is not self.lower:
# If going backwards select from the opposite bound
kwargs["lower"] = {
# Apply to midpoints
return cls(
{k: v[snake_indices] for k, v in self.midpoints.items()},
gap=gap,
# If lower or upper are different, apply to those
lower={
k: np.where(backwards, self.upper[k][snake_indices], v[snake_indices])
for k, v in self.lower.items()
}
if self.midpoints is not self.upper:
kwargs["upper"] = {
if self.midpoints is not self.lower
else None,
upper={
k: np.where(backwards, self.lower[k][snake_indices], v[snake_indices])
for k, v in self.upper.items()
}

# Apply to midpoints
return cls(
{k: v[snake_indices] for k, v in self.midpoints.items()}, gap=gap, **kwargs
if self.midpoints is not self.upper
else None,
)


Expand All @@ -459,7 +483,9 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool:
return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes())


def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
def squash_frames(
stack: list[Frames[Axis]], check_path_changes: bool = True
) -> Frames[Axis]:
"""Squash a stack of nested Frames into a single one.
Args:
Expand Down Expand Up @@ -624,10 +650,7 @@ def __init__(self, stack: list[Frames[Axis]]):
@property
def axes(self) -> list[Axis]:
"""The axes that will be present in each points dictionary."""
axes = []
for frames in self.stack:
axes += frames.axes()
return axes
return list(itertools.chain(*(frames.axes() for frames in self.stack)))

def __len__(self) -> int:
"""The number of dictionaries that will be produced if iterated over."""
Expand Down
Loading

0 comments on commit e1220f5

Please sign in to comment.