Skip to content

Commit c34a23e

Browse files
caseyflexmomchil-flex
authored andcommitted
Added max_num_poles at least min_num_poles validator to FastDispersionFitter
1 parent 0857821 commit c34a23e

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tidy3d/plugins/dispersion/fit_fast.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,12 @@ class FastFitterData(AdvancedFastFitterParam):
149149
title="eps_inf",
150150
description="Value of ``eps_inf``.",
151151
)
152-
poles: ArrayComplex1D = Field(
152+
poles: Optional[ArrayComplex1D] = Field(
153153
None, title="Pole frequencies in eV", description="Pole frequencies in eV"
154154
)
155-
residues: ArrayComplex1D = Field(None, title="Residues in eV", description="Residues in eV")
155+
residues: Optional[ArrayComplex1D] = Field(
156+
None, title="Residues in eV", description="Residues in eV"
157+
)
156158

157159
passivity_optimized: Optional[bool] = Field(
158160
False,
@@ -188,7 +190,11 @@ def _generate_initial_poles(cls, val, values):
188190
"""Generate initial poles."""
189191
if val is not None:
190192
return val
191-
if values["logspacing"] is None or values["smooth"] is None:
193+
if (
194+
values.get("logspacing") is None
195+
or values.get("smooth") is None
196+
or values.get("num_poles") is None
197+
):
192198
return None
193199
omega = values["omega"]
194200
num_poles = values["num_poles"]
@@ -211,7 +217,7 @@ def _generate_initial_residues(cls, val, values):
211217
"""Generate initial residues."""
212218
if val is not None:
213219
return val
214-
poles = values["poles"]
220+
poles = values.get("poles")
215221
if poles is None:
216222
return None
217223
return np.zeros(len(poles))
@@ -675,6 +681,11 @@ def fit(
675681
Best fitting result: (dispersive medium, weighted RMS error).
676682
"""
677683

684+
if max_num_poles < min_num_poles:
685+
raise ValidationError(
686+
"Dispersion fitter cannot have 'max_num_poles' less than 'min_num_poles'."
687+
)
688+
678689
omega = PoleResidue.angular_freq_to_eV(PoleResidue.Hz_to_angular_freq(self.freqs[::-1]))
679690
eps = self.eps_data[::-1]
680691

0 commit comments

Comments
 (0)