@@ -186,14 +186,6 @@ def __init__(self, atoms: Atoms,
186
186
with open (self .ml .model_file , 'rb' ) as f :
187
187
self .ml_model = dill .load (f )
188
188
assert isinstance (self .ml_model , AbstractMLModel )
189
- if self .ml_model .estimator_type != self .ml .estimator :
190
- utils .warn (f'The estimator type of the loaded ML model ({ self .ml_model .estimator_type } ) does not match '
191
- f'the estimator type specified in the Workflow settings ({ self .ml .estimator } ). Overriding...' )
192
- self .ml .estimator = self .ml_model .estimator_type
193
- if self .ml_model .descriptor_type != self .ml .descriptor :
194
- utils .warn (f'The descriptor type of the loaded ML model ({ self .ml_model .descriptor_type } ) does not match '
195
- f'the descriptor type specified in the Workflow settings ({ self .ml .descriptor } ). Overriding...' )
196
- self .ml .descriptor = self .ml_model .descriptor_type
197
189
198
190
else :
199
191
self .ml_model = None
@@ -584,6 +576,14 @@ def _run_sanity_checks(self):
584
576
raise ValueError ("You have requested to train or predict with a machine-learning model, but no model "
585
577
"is attached to this workflow. Either set ml:train or predict to True when initializing "
586
578
"the workflow, or directly add a model to the workflow's ml_model attribute" )
579
+ if self .ml_model .estimator_type != self .ml .estimator :
580
+ utils .warn (f'The estimator type of the loaded ML model ({ self .ml_model .estimator_type } ) does not match '
581
+ f'the estimator type specified in the Workflow settings ({ self .ml .estimator } ). Overriding...' )
582
+ self .ml .estimator = self .ml_model .estimator_type
583
+ if self .ml_model .descriptor_type != self .ml .descriptor :
584
+ utils .warn (f'The descriptor type of the loaded ML model ({ self .ml_model .descriptor_type } ) does not match '
585
+ f'the descriptor type specified in the Workflow settings ({ self .ml .descriptor } ). Overriding...' )
586
+ self .ml .descriptor = self .ml_model .descriptor_type
587
587
if [self .ml .predict , self .ml .train , self .ml .test ].count (True ) > 1 :
588
588
raise ValueError (
589
589
'Training, testing, and using the ML model are mutually exclusive; change `ml:predict` '
0 commit comments