Skip to content

Commit 0c43e2e

Browse files
Move to pyright and fix type errors
1 parent d5cce50 commit 0c43e2e

File tree

6 files changed

+43
-32
lines changed

6 files changed

+43
-32
lines changed

pyproject.toml

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@ classifiers = [
1111
"Programming Language :: Python :: 3.11",
1212
]
1313
description = "Specify step and flyscan paths in a serializable, efficient and Pythonic way"
14-
dependencies = [
15-
"numpy>=2",
16-
"click>=8.1",
17-
"pydantic>=2.0",
18-
]
14+
dependencies = ["numpy>=2", "click>=8.1", "pydantic>=2.0"]
1915
dynamic = ["version"]
2016
license.file = "LICENSE"
2117
readme = "README.md"
@@ -33,11 +29,11 @@ dev = [
3329
"scanspec[plotting]",
3430
"scanspec[service]",
3531
"copier",
36-
"mypy",
3732
"myst-parser",
3833
"pipdeptree",
3934
"pre-commit",
4035
"pydata-sphinx-theme>=0.12",
36+
"pyright",
4137
"pytest",
4238
"pytest-cov",
4339
"ruff",
@@ -65,8 +61,9 @@ name = "Tom Cobb"
6561
[tool.setuptools_scm]
6662
write_to = "src/scanspec/_version.py"
6763

68-
[tool.mypy]
69-
ignore_missing_imports = true # Ignore missing stubs in imported modules
64+
[tool.pyright]
65+
# strict = ["src", "tests"]
66+
reportMissingImports = false # Ignore missing stubs in imported modules
7067

7168
[tool.pytest.ini_options]
7269
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
@@ -99,12 +96,12 @@ passenv = *
9996
allowlist_externals =
10097
pytest
10198
pre-commit
102-
mypy
99+
pyright
103100
sphinx-build
104101
sphinx-autobuild
105102
commands =
106103
pre-commit: pre-commit run --all-files {posargs}
107-
type-checking: mypy src tests {posargs}
104+
type-checking: pyright src tests {posargs}
108105
tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs}
109106
docs: sphinx-{posargs:build -E --keep-going} -T docs build/html
110107
"""
@@ -115,14 +112,21 @@ line-length = 88
115112

116113
[tool.ruff.lint]
117114
extend-select = [
118-
"B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
119-
"C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
120-
"E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
121-
"F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
122-
"W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
123-
"I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
124-
"UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
115+
"B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
116+
"C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
117+
"E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
118+
"F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
119+
"W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
120+
"I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
121+
"UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
122+
"SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self
125123
]
126124
ignore = [
127125
"B008", # We use function calls in service arguments
128126
]
127+
128+
[tool.ruff.lint.per-file-ignores]
129+
# By default, private member access is allowed in tests
130+
# See https://github.com/DiamondLightSource/python-copier-template/issues/154
131+
# Remove this line to forbid private member access in tests
132+
"tests/**/*" = ["SLF001"]

src/scanspec/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def cli(ctx, log_level: str):
2525

2626
# if no command is supplied, print the help message
2727
if ctx.invoked_subcommand is None:
28-
click.echo(cli.get_help(ctx))
28+
click.echo(cli.get_help(ctx)) # type: ignore
2929

3030

3131
@cli.command()

src/scanspec/core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,14 @@
3535

3636
StrictConfig: ConfigDict = {"extra": "forbid"}
3737

38+
C = TypeVar("C")
39+
T = TypeVar("T", type, Callable)
40+
3841

3942
def discriminated_union_of_subclasses(
40-
super_cls: type,
43+
super_cls: type[C],
4144
discriminator: str = "type",
42-
) -> type:
45+
) -> type[C]:
4346
"""Add all subclasses of super_cls to a discriminated union.
4447
4548
For all subclasses of super_cls, add a discriminator field to identify
@@ -137,9 +140,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
137140
return super_cls
138141

139142

140-
T = TypeVar("T", type, Callable)
141-
142-
143143
def uses_tagged_union(cls_or_func: T) -> T:
144144
"""
145145
T = TypeVar("T", type, Callable)
@@ -616,7 +616,7 @@ def consume(self, num: int | None = None) -> Frames[Axis]:
616616

617617
def __len__(self) -> int:
618618
"""Number of frames left in a scan, reduces when `consume` is called."""
619-
return self.end_index - self.index
619+
return int(self.end_index - self.index)
620620

621621

622622
class Midpoints(Generic[Axis]):

src/scanspec/plot.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, xs, ys, zs, *args, **kwargs):
3333
# Added here because of https://github.com/matplotlib/matplotlib/issues/21688
3434
def do_3d_projection(self, renderer=None):
3535
xs3d, ys3d, zs3d = self._verts3d
36-
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
36+
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # type: ignore
3737
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
3838

3939
return np.min(zs)
@@ -109,11 +109,17 @@ def plot_spec(spec: Spec[Any], title: str | None = None):
109109
# Setup axes
110110
if ndims > 2:
111111
plt.figure(figsize=(6, 6))
112-
plt_axes: Axes3D = plt.axes(projection="3d")
112+
plt_axes = plt.axes(projection="3d")
113113
plt_axes.grid(False)
114-
plt_axes.set_zlabel(axes[-3])
115-
plt_axes.set_ylabel(axes[-2])
116-
plt_axes.view_init(elev=15)
114+
if isinstance(plt_axes, Axes3D):
115+
plt_axes.set_zlabel(axes[-3])
116+
plt_axes.set_ylabel(axes[-2])
117+
plt_axes.view_init(elev=15)
118+
else:
119+
raise TypeError(
120+
"Expected matplotlib to create an Axes3D object, "
121+
f"instead got: {plt_axes}"
122+
)
117123
elif ndims == 2:
118124
plt.figure(figsize=(6, 6))
119125
plt_axes = plt.axes()
@@ -208,7 +214,7 @@ def plot_spec(spec: Spec[Any], title: str | None = None):
208214
_plot_arrow(plt_axes, arrow_arr)
209215
elif splines:
210216
# Plot the starting arrow in the direction of the first point
211-
arrow_arr = [(2 * a[0] - a[1], a[0]) for a in splines[0]]
217+
arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]]
212218
_plot_arrow(plt_axes, arrow_arr)
213219
else:
214220
# First point isn't moving, put a right caret marker

src/scanspec/sphinxext.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import contextmanager
22

3+
from docutils.statemachine import StringList
34
from matplotlib.sphinxext import plot_directive
45

56
from . import __version__
@@ -25,7 +26,7 @@ class ExampleSpecDirective(plot_directive.PlotDirective):
2526
"""Runs `plot_spec` on the ``spec`` definied in the content."""
2627

2728
def run(self):
28-
self.content = (
29+
self.content = StringList(
2930
["# Example Spec", "", "from scanspec.plot import plot_spec"]
3031
+ [str(x) for x in self.content]
3132
+ ["plot_spec(spec)"]

tests/test_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def test_gap_repeat() -> None:
492492

493493
def test_gap_repeat_non_snake() -> None:
494494
# Check that no gap doesn't propogate to dim.gap for non-snaked axis
495-
spec: Spec[Any] = Repeat(3, gap=False) * Line.bounded(x, 11, 19, 1)
495+
spec: Spec[str] = Repeat(3, gap=False) * Line.bounded(x, 11, 19, 1)
496496
dim = spec.frames()
497497
assert len(dim) == 3
498498
assert dim.lower == {x: pytest.approx([11, 11, 11])}

0 commit comments

Comments
 (0)