Skip to content

Commit 3a00af8

Browse files
committed
Move warning about ml mismatch until after the header is printed
1 parent af487e2 commit 3a00af8

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/koopmans/workflows/_workflow.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,6 @@ def __init__(self, atoms: Atoms,
186186
with open(self.ml.model_file, 'rb') as f:
187187
self.ml_model = dill.load(f)
188188
assert isinstance(self.ml_model, AbstractMLModel)
189-
if self.ml_model.estimator_type != self.ml.estimator:
190-
utils.warn(f'The estimator type of the loaded ML model ({self.ml_model.estimator_type}) does not match '
191-
f'the estimator type specified in the Workflow settings ({self.ml.estimator}). Overriding...')
192-
self.ml.estimator = self.ml_model.estimator_type
193-
if self.ml_model.descriptor_type != self.ml.descriptor:
194-
utils.warn(f'The descriptor type of the loaded ML model ({self.ml_model.descriptor_type}) does not match '
195-
f'the descriptor type specified in the Workflow settings ({self.ml.descriptor}). Overriding...')
196-
self.ml.descriptor = self.ml_model.descriptor_type
197189

198190
else:
199191
self.ml_model = None
@@ -584,6 +576,14 @@ def _run_sanity_checks(self):
584576
raise ValueError("You have requested to train or predict with a machine-learning model, but no model "
585577
"is attached to this workflow. Either set ml:train or predict to True when initializing "
586578
"the workflow, or directly add a model to the workflow's ml_model attribute")
579+
if self.ml_model.estimator_type != self.ml.estimator:
580+
utils.warn(f'The estimator type of the loaded ML model ({self.ml_model.estimator_type}) does not match '
581+
f'the estimator type specified in the Workflow settings ({self.ml.estimator}). Overriding...')
582+
self.ml.estimator = self.ml_model.estimator_type
583+
if self.ml_model.descriptor_type != self.ml.descriptor:
584+
utils.warn(f'The descriptor type of the loaded ML model ({self.ml_model.descriptor_type}) does not match '
585+
f'the descriptor type specified in the Workflow settings ({self.ml.descriptor}). Overriding...')
586+
self.ml.descriptor = self.ml_model.descriptor_type
587587
if [self.ml.predict, self.ml.train, self.ml.test].count(True) > 1:
588588
raise ValueError(
589589
'Training, testing, and using the ML model are mutually exclusive; change `ml:predict` '

0 commit comments

Comments
 (0)