77from typing import Literal , Optional , Union
88
99import autograd .numpy as np
10+ from numpy .typing import NDArray
1011from pydantic import Field , PositiveFloat
1112
1213from tidy3d .constants import C_0 , ETA_0 , HERTZ , MICROMETER , RADIAN
@@ -131,7 +132,9 @@ def field_data(self) -> FieldData:
131132
132133 return data_raw .updated_copy (** fields_norm )
133134
134- def _field_data_on_grid (self , grid : Grid , background_n : np .ndarray , colocate = True ) -> dict :
135+ def _field_data_on_grid (
136+ self , grid : Grid , background_n : NDArray , colocate : bool = True
137+ ) -> dict [str , ScalarFieldDataArray ]:
135138 """Compute the field data for each field component on a grid for the beam.
136139 A dictionary of the scalar field data arrays is returned, not yet packaged as ``FieldData``.
137140 """
@@ -165,14 +168,14 @@ def _field_data_on_grid(self, grid: Grid, background_n: np.ndarray, colocate=Tru
165168 return scalar_fields
166169
167170 @abstractmethod
168- def scalar_field (self , points : np . ndarray , background_n : float ) -> np . ndarray :
171+ def scalar_field (self , points : NDArray , background_n : float ) -> NDArray :
169172 """Scalar field corresponding to the analytic beam in coordinate system such that the
170173 propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is
171174 computed on an unstructured array ``points`` of shape ``(3, ...)``."""
172175
173176 def analytic_beam_z_normal (
174- self , points : np . ndarray , background_n : float , field : Literal ["E" , "H" ]
175- ) -> np . ndarray :
177+ self , points : NDArray , background_n : float , field : Literal ["E" , "H" ]
178+ ) -> NDArray :
176179 """Analytic beam with all the beam parameters but assuming ``z`` as the normal axis."""
177180
178181 # Add a frequency dimension to points
@@ -212,12 +215,12 @@ def analytic_beam_z_normal(
212215
213216 def analytic_beam (
214217 self ,
215- x : np . ndarray ,
216- y : np . ndarray ,
217- z : np . ndarray ,
218+ x : NDArray ,
219+ y : NDArray ,
220+ z : NDArray ,
218221 background_n : float ,
219222 field : Literal ["E" , "H" ],
220- ) -> np . ndarray :
223+ ) -> NDArray :
221224 """Sample the analytic beam fields on a cartesian grid of points in x, y, z."""
222225
223226 # Make a meshgrid
@@ -241,15 +244,13 @@ def analytic_beam(
241244 # Reshape to (3, Nx, Ny, Nz, num_freqs)
242245 return np .reshape (field_vals , (3 , Nx , Ny , Nz , len (self .freqs )))
243246
244- def _rotate_points_z (self , points : np . ndarray , background_n : np . ndarray ) -> np . ndarray :
247+ def _rotate_points_z (self , points : NDArray , background_n : NDArray ) -> NDArray :
245248 """Rotate points to new coordinates where z is the propagation axis."""
246249 points_prop_z = self .rotate_points (points , [0 , 0 , 1 ], - self .angle_phi )
247250 points_prop_z = self .rotate_points (points_prop_z , [0 , 1 , 0 ], - self .angle_theta )
248251 return points_prop_z
249252
250- def _inverse_rotate_field_vals_z (
251- self , field_vals : np .ndarray , background_n : np .ndarray
252- ) -> np .ndarray :
253+ def _inverse_rotate_field_vals_z (self , field_vals : NDArray , background_n : NDArray ) -> NDArray :
253254 """Rotate field values from coordinates where z is the propagation axis to angled
254255 coordinates."""
255256 field_vals = self .rotate_points (field_vals , [0 , 1 , 0 ], self .angle_theta )
@@ -288,18 +289,18 @@ class PlaneWaveBeamProfile(BeamProfile):
288289 )
289290
290291 @property
291- def _angle_theta_frequency (self ):
292+ def _angle_theta_frequency (self ) -> float :
292293 if not self .angle_theta_frequency :
293294 return np .mean (self .freqs )
294295 return self .angle_theta_frequency
295296
296- def in_plane_k (self , background_n : float ):
297+ def in_plane_k (self , background_n : float ) -> list [ float ] :
297298 """In-plane wave vector. Only the real part is taken so the beam has no in-plane decay."""
298299 k0 = 2 * np .pi * self ._angle_theta_frequency / C_0 * background_n
299300 k_in_plane = k0 .real * np .sin (self .angle_theta )
300301 return [k_in_plane * np .cos (self .angle_phi ), k_in_plane * np .sin (self .angle_phi )]
301302
302- def scalar_field (self , points : np . ndarray , background_n : float ) -> np . ndarray :
303+ def scalar_field (self , points : NDArray , background_n : float ) -> NDArray :
303304 """Scalar field for plane wave.
304305 Scalar field corresponding to the analytic beam in coordinate system such that the
305306 propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is
@@ -314,14 +315,14 @@ def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray:
314315 kz *= np .cos (self .angle_theta )
315316 return np .exp (1j * points [2 ] * kz )
316317
317- def _angle_theta_actual (self , background_n : np . ndarray ) -> np . ndarray :
318+ def _angle_theta_actual (self , background_n : NDArray ) -> NDArray :
318319 """Compute the frequency-dependent actual propagation angle theta."""
319320 k0 = 2 * np .pi * np .array (self .freqs ) / C_0 * background_n
320321 kx , ky = self .in_plane_k (background_n )
321322 k_perp = np .sqrt (kx ** 2 + ky ** 2 )
322323 return np .real (np .arcsin (k_perp / k0 )) * np .sign (self .angle_theta )
323324
324- def _rotate_points_z (self , points : np . ndarray , background_n : np . ndarray ) -> np . ndarray :
325+ def _rotate_points_z (self , points : NDArray , background_n : NDArray ) -> NDArray :
325326 """Rotate points to new coordinates where z is the propagation axis."""
326327 if self .as_fixed_angle_source :
327328 # For fixed-angle, we do not rotate the points
@@ -335,9 +336,7 @@ def _rotate_points_z(self, points: np.ndarray, background_n: np.ndarray) -> np.n
335336 return points
336337 return super ()._rotate_points_z (points , background_n )
337338
338- def _inverse_rotate_field_vals_z (
339- self , field_vals : np .ndarray , background_n : np .ndarray
340- ) -> np .ndarray :
339+ def _inverse_rotate_field_vals_z (self , field_vals : NDArray , background_n : NDArray ) -> NDArray :
341340 """Rotate field values from coordinates where z is the propagation axis to angled
342341 coordinates. Special handling is needed if fixed in-plane k wave."""
343342 if isinstance (self .angular_spec , FixedInPlaneKSpec ):
@@ -378,9 +377,7 @@ class GaussianBeamProfile(BeamProfile):
378377 units = MICROMETER ,
379378 )
380379
381- def beam_params (
382- self , z : np .ndarray , k0 : np .ndarray
383- ) -> tuple [np .ndarray , np .ndarray , np .ndarray ]:
380+ def beam_params (self , z : NDArray , k0 : NDArray ) -> tuple [NDArray , NDArray , NDArray ]:
384381 """Compute the parameters needed to evaluate a Gaussian beam at z.
385382
386383 Parameters
@@ -402,7 +399,7 @@ def beam_params(
402399 psi_g = np .arctan ((z + z_0 ) / z_r ) - np .arctan (z_0 / z_r )
403400 return w_z , inv_r_z , psi_g
404401
405- def scalar_field (self , points : np . ndarray , background_n : float ) -> np . ndarray :
402+ def scalar_field (self , points : NDArray , background_n : float ) -> NDArray :
406403 """Scalar field for Gaussian beam.
407404 Scalar field corresponding to the analytic beam in coordinate system such that the
408405 propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is
@@ -446,9 +443,7 @@ class AstigmaticGaussianBeamProfile(BeamProfile):
446443 units = MICROMETER ,
447444 )
448445
449- def beam_params (
450- self , z : np .ndarray , k0 : np .ndarray
451- ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
446+ def beam_params (self , z : NDArray , k0 : NDArray ) -> tuple [NDArray , NDArray , NDArray , NDArray ]:
452447 """Compute the parameters needed to evaluate an astigmatic Gaussian beam at z.
453448
454449 Parameters
@@ -475,7 +470,7 @@ def beam_params(
475470
476471 return w_0 , w_z , inv_r_z , psi_g
477472
478- def scalar_field (self , points : np . ndarray , background_n : float ) -> np . ndarray :
473+ def scalar_field (self , points : NDArray , background_n : float ) -> NDArray :
479474 """
480475 Scalar field for astigmatic Gaussian beam.
481476 Scalar field corresponding to the analytic beam in coordinate system such that the
0 commit comments