@@ -68,6 +68,7 @@ def mace_mp(
68
68
damping : str = "bj" , # choices: ["zero", "bj", "zerom", "bjm"]
69
69
dispersion_xc : str = "pbe" ,
70
70
dispersion_cutoff : float = 40.0 * units .Bohr ,
71
+ return_raw_model : bool = False ,
71
72
** kwargs ,
72
73
) -> MACECalculator :
73
74
"""
@@ -93,6 +94,7 @@ def mace_mp(
93
94
damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
94
95
dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
95
96
dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
97
+ return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
96
98
**kwargs: Passed to MACECalculator and TorchDFTD3Calculator.
97
99
98
100
Returns:
@@ -114,6 +116,9 @@ def mace_mp(
114
116
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
115
117
)
116
118
119
+ if return_raw_model :
120
+ return torch .load (model_path , map_location = device )
121
+
117
122
mace_calc = MACECalculator (
118
123
model_paths = model_path , device = device , default_dtype = default_dtype , ** kwargs
119
124
)
@@ -221,6 +226,7 @@ def mace_off(
221
226
def mace_anicc (
222
227
device : str = "cuda" ,
223
228
model_path : str = None ,
229
+ return_raw_model : bool = False ,
224
230
) -> MACECalculator :
225
231
"""
226
232
Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O).
@@ -236,6 +242,8 @@ def mace_anicc(
236
242
print (
237
243
"Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
238
244
)
245
+ if return_raw_model :
246
+ return torch .load (model_path , map_location = device )
239
247
return MACECalculator (
240
248
model_paths = model_path , device = device , default_dtype = "float64"
241
249
)
0 commit comments