Skip to content

Commit 6be2a94

Browse files
committed
add option to return raw model in mace_mp
1 parent 6b7d9c9 commit 6be2a94

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

mace/calculators/foundations_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def mace_mp(
6868
damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"]
6969
dispersion_xc: str = "pbe",
7070
dispersion_cutoff: float = 40.0 * units.Bohr,
71+
return_raw_model: bool = False,
7172
**kwargs,
7273
) -> MACECalculator:
7374
"""
@@ -93,6 +94,7 @@ def mace_mp(
9394
damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ).
9495
dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections.
9596
dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections.
97+
return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False.
9698
**kwargs: Passed to MACECalculator and TorchDFTD3Calculator.
9799
98100
Returns:
@@ -114,6 +116,9 @@ def mace_mp(
114116
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
115117
)
116118

119+
if return_raw_model:
120+
return torch.load(model_path, map_location=device)
121+
117122
mace_calc = MACECalculator(
118123
model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs
119124
)
@@ -221,6 +226,7 @@ def mace_off(
221226
def mace_anicc(
222227
device: str = "cuda",
223228
model_path: str = None,
229+
return_raw_model: bool = False,
224230
) -> MACECalculator:
225231
"""
226232
Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O).
@@ -236,6 +242,8 @@ def mace_anicc(
236242
print(
237243
"Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322"
238244
)
245+
if return_raw_model:
246+
return torch.load(model_path, map_location=device)
239247
return MACECalculator(
240248
model_paths=model_path, device=device, default_dtype="float64"
241249
)

0 commit comments

Comments
 (0)