Skip to content

Commit fe6f185

Browse files
chore(tidy3d): FXC-4318-add-mypy-typedefs-in-rest-of-repo-except-big-component-files
1 parent de35229 commit fe6f185

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+493
-349
lines changed

pyproject.toml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -341,17 +341,15 @@ python_files = "*.py"
341341
[tool.mypy]
342342
python_version = "3.10"
343343
files = [
344-
"tidy3d/components/autograd",
345-
"tidy3d/components/data",
346-
"tidy3d/components/geometry",
347-
"tidy3d/components/grid",
348-
"tidy3d/components/mode",
349-
"tidy3d/components/tcad",
350-
"tidy3d/components/viz",
351-
"tidy3d/config",
352-
"tidy3d/material_library",
353-
"tidy3d/plugins",
354-
"tidy3d/web",
344+
"tidy3d",
345+
]
346+
exclude = [
347+
"tidy3d/components/simulation.py",
348+
"tidy3d/components/validators.py",
349+
"tidy3d/components/medium.py",
350+
"tidy3d/components/boundary.py",
351+
"tidy3d/components/base.py",
352+
"tidy3d/components/monitor.py",
355353
]
356354
ignore_missing_imports = true
357355
follow_imports = "skip"

tidy3d/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tidy3d.web import Job
1010

1111

12-
def main(args) -> None:
12+
def main(args: list[str]) -> None:
1313
"""Parse args and run the corresponding tidy3d simulaton."""
1414

1515
parser = argparse.ArgumentParser(description="Tidy3D")

tidy3d/components/base_sim/data/sim_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import xarray as xr
1010
from pydantic import Field, field_validator, model_validator
1111

12+
from tidy3d.compat import Self
1213
from tidy3d.components.base import Tidy3dBaseModel
1314
from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData
1415
from tidy3d.components.base_sim.simulation import AbstractSimulation
@@ -51,7 +52,7 @@ def monitor_data(self) -> dict[str, AbstractMonitorData]:
5152
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}
5253

5354
@model_validator(mode="after")
54-
def data_monitors_match_sim(self):
55+
def data_monitors_match_sim(self) -> Self:
5556
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
5657
``.simulation``.
5758
"""
@@ -70,7 +71,9 @@ def data_monitors_match_sim(self):
7071

7172
@field_validator("data")
7273
@classmethod
73-
def validate_no_ambiguity(cls, val):
74+
def validate_no_ambiguity(
75+
cls, val: tuple[AbstractMonitorData, ...]
76+
) -> tuple[AbstractMonitorData, ...]:
7477
"""Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
7578
monitors in ``.simulation``.
7679
"""

tidy3d/components/base_sim/simulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class AbstractSimulation(Box, ABC):
137137

138138
@model_validator(mode="before")
139139
@classmethod
140-
def _update_simulation(cls, data):
140+
def _update_simulation(cls, data: dict[str, Any]) -> dict[str, Any]:
141141
"""Update the simulation if it is an earlier version."""
142142
# dummy upgrade of version number
143143
# this should be overriden by each simulation class if needed
@@ -161,7 +161,7 @@ def _update_simulation(cls, data):
161161
_warn_traced_size = _warn_unsupported_traced_argument("size")
162162

163163
@model_validator(mode="after")
164-
def _structures_not_at_edges(self):
164+
def _structures_not_at_edges(self) -> Self:
165165
"""Warn if any structures lie at the simulation boundaries."""
166166

167167
if self.structures is None:
@@ -704,7 +704,7 @@ def from_scene(cls, scene: Scene, **kwargs: Any) -> AbstractSimulation:
704704
**kwargs,
705705
)
706706

707-
def plot_3d(self, width=800, height=800) -> None:
707+
def plot_3d(self, width: float = 800, height: float = 800) -> None:
708708
"""Render 3D plot of ``AbstractSimulation`` (in jupyter notebook only).
709709
Parameters
710710
----------

tidy3d/components/bc_placement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class StructureStructureInterface(AbstractBCPlacement):
4747

4848
@field_validator("structures")
4949
@classmethod
50-
def unique_names(cls, val):
50+
def unique_names(cls, val: tuple[str, str]) -> tuple[str, str]:
5151
"""Error if the same structure is provided twice"""
5252
if val[0] == val[1]:
5353
raise SetupError(
@@ -71,7 +71,7 @@ class MediumMediumInterface(AbstractBCPlacement):
7171

7272
@field_validator("mediums")
7373
@classmethod
74-
def unique_names(cls, val):
74+
def unique_names(cls, val: tuple[str, str]) -> tuple[str, str]:
7575
"""Error if the same structure is provided twice"""
7676
if val[0] == val[1]:
7777
raise SetupError("The same medium is provided twice in 'MediumMediumInterface'.")

tidy3d/components/beam.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Literal, Optional, Union
88

99
import autograd.numpy as np
10+
from numpy.typing import NDArray
1011
from pydantic import Field, PositiveFloat
1112

1213
from tidy3d.constants import C_0, ETA_0, HERTZ, MICROMETER, RADIAN
@@ -131,7 +132,9 @@ def field_data(self) -> FieldData:
131132

132133
return data_raw.updated_copy(**fields_norm)
133134

134-
def _field_data_on_grid(self, grid: Grid, background_n: np.ndarray, colocate=True) -> dict:
135+
def _field_data_on_grid(
136+
self, grid: Grid, background_n: NDArray, colocate: bool = True
137+
) -> dict[str, ScalarFieldDataArray]:
135138
"""Compute the field data for each field component on a grid for the beam.
136139
A dictionary of the scalar field data arrays is returned, not yet packaged as ``FieldData``.
137140
"""
@@ -165,14 +168,14 @@ def _field_data_on_grid(self, grid: Grid, background_n: np.ndarray, colocate=Tru
165168
return scalar_fields
166169

167170
@abstractmethod
168-
def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray:
171+
def scalar_field(self, points: NDArray, background_n: float) -> NDArray:
169172
"""Scalar field corresponding to the analytic beam in coordinate system such that the
170173
propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is
171174
computed on an unstructured array ``points`` of shape ``(3, ...)``."""
172175

173176
def analytic_beam_z_normal(
174-
self, points: np.ndarray, background_n: float, field: Literal["E", "H"]
175-
) -> np.ndarray:
177+
self, points: NDArray, background_n: float, field: Literal["E", "H"]
178+
) -> NDArray:
176179
"""Analytic beam with all the beam parameters but assuming ``z`` as the normal axis."""
177180

178181
# Add a frequency dimension to points
@@ -212,12 +215,12 @@ def analytic_beam_z_normal(
212215

213216
def analytic_beam(
214217
self,
215-
x: np.ndarray,
216-
y: np.ndarray,
217-
z: np.ndarray,
218+
x: NDArray,
219+
y: NDArray,
220+
z: NDArray,
218221
background_n: float,
219222
field: Literal["E", "H"],
220-
) -> np.ndarray:
223+
) -> NDArray:
221224
"""Sample the analytic beam fields on a cartesian grid of points in x, y, z."""
222225

223226
# Make a meshgrid
@@ -241,15 +244,13 @@ def analytic_beam(
241244
# Reshape to (3, Nx, Ny, Nz, num_freqs)
242245
return np.reshape(field_vals, (3, Nx, Ny, Nz, len(self.freqs)))
243246

244-
def _rotate_points_z(self, points: np.ndarray, background_n: np.ndarray) -> np.ndarray:
247+
def _rotate_points_z(self, points: NDArray, background_n: NDArray) -> NDArray:
245248
"""Rotate points to new coordinates where z is the propagation axis."""
246249
points_prop_z = self.rotate_points(points, [0, 0, 1], -self.angle_phi)
247250
points_prop_z = self.rotate_points(points_prop_z, [0, 1, 0], -self.angle_theta)
248251
return points_prop_z
249252

250-
def _inverse_rotate_field_vals_z(
251-
self, field_vals: np.ndarray, background_n: np.ndarray
252-
) -> np.ndarray:
253+
def _inverse_rotate_field_vals_z(self, field_vals: NDArray, background_n: NDArray) -> NDArray:
253254
"""Rotate field values from coordinates where z is the propagation axis to angled
254255
coordinates."""
255256
field_vals = self.rotate_points(field_vals, [0, 1, 0], self.angle_theta)
@@ -288,18 +289,18 @@ class PlaneWaveBeamProfile(BeamProfile):
288289
)
289290

290291
@property
291-
def _angle_theta_frequency(self):
292+
def _angle_theta_frequency(self) -> float:
292293
if not self.angle_theta_frequency:
293294
return np.mean(self.freqs)
294295
return self.angle_theta_frequency
295296

296-
def in_plane_k(self, background_n: float):
297+
def in_plane_k(self, background_n: float) -> list[float]:
297298
"""In-plane wave vector. Only the real part is taken so the beam has no in-plane decay."""
298299
k0 = 2 * np.pi * self._angle_theta_frequency / C_0 * background_n
299300
k_in_plane = k0.real * np.sin(self.angle_theta)
300301
return [k_in_plane * np.cos(self.angle_phi), k_in_plane * np.sin(self.angle_phi)]
301302

302-
def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray:
303+
def scalar_field(self, points: NDArray, background_n: float) -> NDArray:
303304
"""Scalar field for plane wave.
304305
Scalar field corresponding to the analytic beam in coordinate system such that the
305306
propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is
@@ -314,14 +315,14 @@ def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray:
314315
kz *= np.cos(self.angle_theta)
315316
return np.exp(1j * points[2] * kz)
316317

317-
def _angle_theta_actual(self, background_n: np.ndarray) -> np.ndarray:
318+
def _angle_theta_actual(self, background_n: NDArray) -> NDArray:
318319
"""Compute the frequency-dependent actual propagation angle theta."""
319320
k0 = 2 * np.pi * np.array(self.freqs) / C_0 * background_n
320321
kx, ky = self.in_plane_k(background_n)
321322
k_perp = np.sqrt(kx**2 + ky**2)
322323
return np.real(np.arcsin(k_perp / k0)) * np.sign(self.angle_theta)
323324

324-
def _rotate_points_z(self, points: np.ndarray, background_n: np.ndarray) -> np.ndarray:
325+
def _rotate_points_z(self, points: NDArray, background_n: NDArray) -> NDArray:
325326
"""Rotate points to new coordinates where z is the propagation axis."""
326327
if self.as_fixed_angle_source:
327328
# For fixed-angle, we do not rotate the points
@@ -335,9 +336,7 @@ def _rotate_points_z(self, points: np.ndarray, background_n: np.ndarray) -> np.n
335336
return points
336337
return super()._rotate_points_z(points, background_n)
337338

338-
def _inverse_rotate_field_vals_z(
339-
self, field_vals: np.ndarray, background_n: np.ndarray
340-
) -> np.ndarray:
339+
def _inverse_rotate_field_vals_z(self, field_vals: NDArray, background_n: NDArray) -> NDArray:
341340
"""Rotate field values from coordinates where z is the propagation axis to angled
342341
coordinates. Special handling is needed if fixed in-plane k wave."""
343342
if isinstance(self.angular_spec, FixedInPlaneKSpec):
@@ -378,9 +377,7 @@ class GaussianBeamProfile(BeamProfile):
378377
units=MICROMETER,
379378
)
380379

381-
def beam_params(
382-
self, z: np.ndarray, k0: np.ndarray
383-
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
380+
def beam_params(self, z: NDArray, k0: NDArray) -> tuple[NDArray, NDArray, NDArray]:
384381
"""Compute the parameters needed to evaluate a Gaussian beam at z.
385382
386383
Parameters
@@ -402,7 +399,7 @@ def beam_params(
402399
psi_g = np.arctan((z + z_0) / z_r) - np.arctan(z_0 / z_r)
403400
return w_z, inv_r_z, psi_g
404401

405-
def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray:
402+
def scalar_field(self, points: NDArray, background_n: float) -> NDArray:
406403
"""Scalar field for Gaussian beam.
407404
Scalar field corresponding to the analytic beam in coordinate system such that the
408405
propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is
@@ -446,9 +443,7 @@ class AstigmaticGaussianBeamProfile(BeamProfile):
446443
units=MICROMETER,
447444
)
448445

449-
def beam_params(
450-
self, z: np.ndarray, k0: np.ndarray
451-
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
446+
def beam_params(self, z: NDArray, k0: NDArray) -> tuple[NDArray, NDArray, NDArray, NDArray]:
452447
"""Compute the parameters needed to evaluate an astigmatic Gaussian beam at z.
453448
454449
Parameters
@@ -475,7 +470,7 @@ def beam_params(
475470

476471
return w_0, w_z, inv_r_z, psi_g
477472

478-
def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray:
473+
def scalar_field(self, points: NDArray, background_n: float) -> NDArray:
479474
"""
480475
Scalar field for astigmatic Gaussian beam.
481476
Scalar field corresponding to the analytic beam in coordinate system such that the

tidy3d/components/data/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6-
from typing import Any, Callable, Literal, Optional, Self, Union, get_args
6+
from typing import Any, Callable, Literal, Optional, Union, get_args
77

88
import numpy as np
99
import xarray as xr
1010
from numpy.typing import ArrayLike
1111
from pydantic import Field
1212

13+
from tidy3d.compat import Self
1314
from tidy3d.components.base import Tidy3dBaseModel
1415
from tidy3d.components.types import Axis, FreqArray, xyz
1516
from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling

tidy3d/components/data/monitor_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from abc import ABC
88
from math import isclose
99
from os import PathLike
10-
from typing import Any, Callable, Literal, Optional, Self, SupportsComplex, Union, get_args
10+
from typing import Any, Callable, Literal, Optional, SupportsComplex, Union, get_args
1111

1212
import autograd.numpy as np
1313
import xarray as xr
1414
from numpy.typing import NDArray
1515
from pandas import DataFrame
1616
from pydantic import Field, model_validator
1717

18+
from tidy3d.compat import Self
1819
from tidy3d.components.base import cached_property
1920
from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData
2021
from tidy3d.components.grid.grid import Coords, Grid
@@ -265,7 +266,7 @@ def symmetry_expanded(self) -> Self:
265266
return self.updated_copy(**self._symmetry_update_dict, deep=False, validate=False)
266267

267268
@property
268-
def symmetry_expanded_copy(self) -> AbstractFieldData:
269+
def symmetry_expanded_copy(self) -> Self:
269270
"""Create a copy of the :class:`.AbstractFieldData` with fields expanded based on symmetry.
270271
271272
Returns

tidy3d/components/data/unstructured/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numbers
66
from abc import ABC, abstractmethod
77
from os import PathLike
8-
from typing import TYPE_CHECKING, Any, Literal, Optional, Self, Union
8+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
99

1010
import numpy as np
1111
from numpy.typing import DTypeLike, NDArray
@@ -14,6 +14,7 @@
1414
from vtkmodules.vtkCommonCore import vtkPoints
1515
from xarray import DataArray as XrDataArray
1616

17+
from tidy3d.compat import Self
1718
from tidy3d.components.base import cached_property
1819
from tidy3d.components.data.data_array import (
1920
DATA_ARRAY_MAP,

tidy3d/components/data/unstructured/triangular.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, Literal, Optional, Self, Union
5+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
66

77
import numpy as np
88
from pydantic import Field, PositiveInt
99
from xarray import DataArray
1010
from xarray import DataArray as XrDataArray
1111

12+
from tidy3d.compat import Self
13+
1214
try:
1315
from matplotlib import pyplot as plt
1416
from matplotlib.tri import Triangulation

0 commit comments

Comments
 (0)