Skip to content

Commit e953c00

Browse files
dbochkov-flexcomputemomchil-flex
authored andcommitted
unified validation check for missing dependency fields
1 parent efb87ae commit e953c00

28 files changed

+140
-74
lines changed

tests/test_components/test_simulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,11 +505,11 @@ def test_validate_components_none():
505505
assert SIM._source_homogeneous_isotropic(val=None, values=SIM.dict()) is None
506506

507507

508-
def test_sources_edge_case_validation():
508+
def test_sources_edge_case_validation(log_capture):
509509
values = SIM.dict()
510510
values.pop("sources")
511-
with pytest.raises(ValidationError):
512-
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
511+
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
512+
assert_log_level(log_capture, "WARNING")
513513

514514

515515
def test_validate_size_run_time(monkeypatch):

tidy3d/components/apodization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pydantic.v1 as pd
44
import numpy as np
55

6-
from .base import Tidy3dBaseModel
6+
from .base import Tidy3dBaseModel, skip_if_fields_missing
77
from ..constants import SECOND
88
from ..exceptions import SetupError
99
from .types import ArrayFloat1D, Ax
@@ -40,6 +40,7 @@ class ApodizationSpec(Tidy3dBaseModel):
4040
)
4141

4242
@pd.validator("end", always=True, allow_reuse=True)
43+
@skip_if_fields_missing(["start"])
4344
def end_greater_than_start(cls, val, values):
4445
"""Ensure end is greater than or equal to start."""
4546
start = values.get("start")
@@ -48,6 +49,7 @@ def end_greater_than_start(cls, val, values):
4849
return val
4950

5051
@pd.validator("width", always=True, allow_reuse=True)
52+
@skip_if_fields_missing(["start", "end"])
5153
def width_provided(cls, val, values):
5254
"""Check that width is provided if either start or end apodization is requested."""
5355
start = values.get("start")

tidy3d/components/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,28 @@ def _get_valid_extension(fname: str) -> str:
8686
)
8787

8888

89+
def skip_if_fields_missing(fields: List[str]):
90+
"""Decorate ``validator`` to check that other fields have passed validation."""
91+
92+
def actual_decorator(validator):
93+
@wraps(validator)
94+
def _validator(cls, val, values):
95+
"""New validator function."""
96+
for field in fields:
97+
if field not in values:
98+
log.warning(
99+
f"Could not execute validator '{validator.__name__}' because field "
100+
f"'{field}' failed validation."
101+
)
102+
return val
103+
104+
return validator(cls, val, values)
105+
106+
return _validator
107+
108+
return actual_decorator
109+
110+
89111
class Tidy3dBaseModel(pydantic.BaseModel):
90112
"""Base pydantic model that all Tidy3d components inherit from.
91113
Defines configuration for handling data structures

tidy3d/components/base_sim/data/sim_data.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..simulation import AbstractSimulation
1313
from ...data.dataset import UnstructuredGridDatasetType
1414
from ...base import Tidy3dBaseModel
15+
from ...base import skip_if_fields_missing
1516
from ...types import FieldVal
1617
from ....exceptions import DataError, Tidy3dKeyError, ValidationError
1718

@@ -51,13 +52,13 @@ def monitor_data(self) -> Dict[str, AbstractMonitorData]:
5152
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}
5253

5354
@pd.validator("data", always=True)
55+
@skip_if_fields_missing(["simulation"])
5456
def data_monitors_match_sim(cls, val, values):
5557
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
5658
``.simulation``.
5759
"""
5860
sim = values.get("simulation")
59-
if sim is None:
60-
raise ValidationError("'.simulation' failed validation, can't validate data.")
61+
6162
for mnt_data in val:
6263
try:
6364
monitor_name = mnt_data.monitor.name
@@ -70,14 +71,11 @@ def data_monitors_match_sim(cls, val, values):
7071
return val
7172

7273
@pd.validator("data", always=True)
74+
@skip_if_fields_missing(["simulation"])
7375
def validate_no_ambiguity(cls, val, values):
7476
"""Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
7577
monitors in ``.simulation``.
7678
"""
77-
sim = values.get("simulation")
78-
if sim is None:
79-
raise ValidationError("'.simulation' failed validation, can't validate data.")
80-
8179
names = [mnt_data.monitor.name for mnt_data in val]
8280

8381
if len(set(names)) != len(names):

tidy3d/components/base_sim/simulation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .monitor import AbstractMonitor
1111

12-
from ..base import cached_property
12+
from ..base import cached_property, skip_if_fields_missing
1313
from ..validators import assert_unique_names, assert_objects_in_sim_bounds
1414
from ..geometry.base import Box
1515
from ..types import Ax, Bound, Axis, Symmetry, TYPE_TAG_STR
@@ -97,6 +97,7 @@ class AbstractSimulation(Box, ABC):
9797
_structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False)
9898

9999
@pd.validator("structures", always=True)
100+
@skip_if_fields_missing(["size", "center"])
100101
def _structures_not_at_edges(cls, val, values):
101102
"""Warn if any structures lie at the simulation boundaries."""
102103

tidy3d/components/data/dataset.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
2222
from ..base import Tidy3dBaseModel, cached_property
23+
from ..base import skip_if_fields_missing
2324
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal
2425
from ..types import vtk, requires_vtk
2526
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
@@ -524,13 +525,12 @@ def match_cells_to_vtk_type(cls, val):
524525
return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords)
525526

526527
@pd.validator("values", always=True)
528+
@skip_if_fields_missing(["points"])
527529
def number_of_values_matches_points(cls, val, values):
528530
"""Check that the number of data values matches the number of grid points."""
529531
num_values = len(val)
530532

531533
points = values.get("points")
532-
if points is None:
533-
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
534534
num_points = len(points)
535535

536536
if num_points != num_values:
@@ -565,15 +565,14 @@ def cells_right_type(cls, val):
565565
return val
566566

567567
@pd.validator("cells", always=True)
568+
@skip_if_fields_missing(["points"])
568569
def check_cell_vertex_range(cls, val, values):
569570
"""Check that cell connections use only defined points."""
570571
all_point_indices_used = val.data.ravel()
571572
min_index_used = np.min(all_point_indices_used)
572573
max_index_used = np.max(all_point_indices_used)
573574

574575
points = values.get("points")
575-
if points is None:
576-
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
577576
num_points = len(points)
578577

579578
if max_index_used != num_points - 1 or min_index_used != 0:

tidy3d/components/data/monitor_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .data_array import FreqDataArray, TimeDataArray, FreqModeDataArray
2121
from .dataset import Dataset, AbstractFieldDataset, ElectromagneticFieldDataset
2222
from .dataset import FieldDataset, FieldTimeDataset, ModeSolverDataset, PermittivityDataset
23-
from ..base import TYPE_TAG_STR, cached_property
23+
from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing
2424
from ..types import Coordinate, Symmetry, ArrayFloat1D, ArrayFloat2D, Size, Numpy, TrackFreq
2525
from ..types import EpsSpecType, Literal
2626
from ..grid.grid import Grid, Coords
@@ -926,6 +926,7 @@ class ModeSolverData(ModeSolverDataset, ElectromagneticFieldData):
926926
)
927927

928928
@pd.validator("eps_spec", always=True)
929+
@skip_if_fields_missing(["monitor"])
929930
def eps_spec_match_mode_spec(cls, val, values):
930931
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
931932
if val:

tidy3d/components/field_projection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .monitor import FieldProjectionCartesianMonitor, FieldProjectionKSpaceMonitor
2020
from .types import Direction, Coordinate, ArrayComplex4D
2121
from .medium import MediumType
22-
from .base import Tidy3dBaseModel, cached_property
22+
from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
2323
from ..exceptions import SetupError
2424
from ..constants import C_0, MICROMETER, ETA_0, EPSILON_0, MU_0
2525
from ..log import get_logging_console
@@ -72,6 +72,7 @@ class FieldProjector(Tidy3dBaseModel):
7272
)
7373

7474
@pydantic.validator("origin", always=True)
75+
@skip_if_fields_missing(["surfaces"])
7576
def set_origin(cls, val, values):
7677
"""Sets .origin as the average of centers of all surface monitors if not provided."""
7778
if val is None:

tidy3d/components/geometry/polyslab.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from matplotlib import path
1212

1313
from ..base import cached_property
14+
from ..base import skip_if_fields_missing
1415
from ..types import Axis, Bound, PlanePosition, ArrayFloat2D, Coordinate
1516
from ..types import MatrixReal4x4, Shapely, trimesh
1617
from ...log import log
@@ -105,6 +106,7 @@ def correct_shape(cls, val):
105106
return val
106107

107108
@pydantic.validator("vertices", always=True)
109+
@skip_if_fields_missing(["dilation"])
108110
def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
109111
"""At the reference plane, check if the polygon is self-intersecting.
110112
@@ -154,6 +156,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
154156
return val
155157

156158
@pydantic.validator("vertices", always=True)
159+
@skip_if_fields_missing(["sidewall_angle", "dilation", "slab_bounds", "reference_plane"])
157160
def no_self_intersecting_polygon_during_extrusion(cls, val, values):
158161
"""In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
159162
any normal cross section of the PolySlab cannot be self-intersecting. This part checks
@@ -168,8 +171,6 @@ def no_self_intersecting_polygon_during_extrusion(cls, val, values):
168171
To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation
169172
of polygons/holes, and changes in vertices number.
170173
"""
171-
if "sidewall_angle" not in values:
172-
raise ValidationError("'sidewall_angle' failed validation.")
173174

174175
# no need to valiate anything here
175176
if isclose(values["sidewall_angle"], 0):

tidy3d/components/geometry/primitives.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import shapely
1010

11-
from ..base import cached_property
11+
from ..base import cached_property, skip_if_fields_missing
1212
from ..types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely, trimesh
1313
from ...exceptions import SetupError, ValidationError
1414
from ...constants import MICROMETER, LARGE_NUMBER
@@ -191,6 +191,7 @@ class Cylinder(base.Centered, base.Circular, base.Planar):
191191
)
192192

193193
@pydantic.validator("length", always=True)
194+
@skip_if_fields_missing(["sidewall_angle", "reference_plane"])
194195
def _only_middle_for_infinite_length_slanted_cylinder(cls, val, values):
195196
"""For a slanted cylinder of infinite length, ``reference_plane`` can only
196197
be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0.

tidy3d/components/heat/data/monitor_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pydantic.v1 as pd
88

99
from ..monitor import TemperatureMonitor, HeatMonitorType
10+
from ...base import skip_if_fields_missing
1011
from ...base_sim.data.monitor_data import AbstractMonitorData
1112
from ...data.data_array import SpatialDataArray
1213
from ...data.dataset import TriangularGridDataset, TetrahedralGridDataset
@@ -74,6 +75,7 @@ class TemperatureData(HeatMonitorData):
7475
)
7576

7677
@pd.validator("temperature", always=True)
78+
@skip_if_fields_missing(["monitor"])
7779
def warn_no_data(cls, val, values):
7880
"""Warn if no data provided."""
7981

tidy3d/components/heat/grid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Union, Tuple
55
import pydantic.v1 as pd
66

7-
from ..base import Tidy3dBaseModel
7+
from ..base import Tidy3dBaseModel, skip_if_fields_missing
88
from ...constants import MICROMETER
99
from ...exceptions import ValidationError
1010

@@ -107,6 +107,7 @@ class DistanceUnstructuredGrid(Tidy3dBaseModel):
107107
)
108108

109109
@pd.validator("distance_bulk", always=True)
110+
@skip_if_fields_missing(["distance_interface"])
110111
def names_exist_bcs(cls, val, values):
111112
"""Error if distance_bulk is less than distance_interface"""
112113
distance_interface = values.get("distance_interface")

tidy3d/components/heat/simulation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .viz import plot_params_heat_bc, plot_params_heat_source, HEAT_SOURCE_CMAP
1616

1717
from ..base_sim.simulation import AbstractSimulation
18-
from ..base import cached_property
18+
from ..base import cached_property, skip_if_fields_missing
1919
from ..types import Ax, Shapely, TYPE_TAG_STR, ScalarSymmetry, Bound
2020
from ..viz import add_ax_if_none, equal_aspect, PlotParams
2121
from ..structure import Structure
@@ -139,6 +139,7 @@ def check_zero_dim_domain(cls, val, values):
139139
return val
140140

141141
@pd.validator("boundary_spec", always=True)
142+
@skip_if_fields_missing(["structures", "medium"])
142143
def names_exist_bcs(cls, val, values):
143144
"""Error if boundary conditions point to non-existing structures/media."""
144145

@@ -175,6 +176,7 @@ def names_exist_bcs(cls, val, values):
175176
return val
176177

177178
@pd.validator("grid_spec", always=True)
179+
@skip_if_fields_missing(["structures"])
178180
def names_exist_grid_spec(cls, val, values):
179181
"""Warn if UniformUnstructuredGrid points at a non-existing structure."""
180182

@@ -191,6 +193,7 @@ def names_exist_grid_spec(cls, val, values):
191193
return val
192194

193195
@pd.validator("sources", always=True)
196+
@skip_if_fields_missing(["structures"])
194197
def names_exist_sources(cls, val, values):
195198
"""Error if a heat source point to non-existing structures."""
196199
structures = values.get("structures")

0 commit comments

Comments
 (0)