Skip to content

Commit

Permalink
Add MatterSim calculator (#427)
Browse files Browse the repository at this point in the history
* add mattersim calculator fixes #425

* Update pyproject.toml

Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>

* Fix tests for optional MACE

* Allow model to be passed to mattersim

* Add mattersim tests

* Fix mattersim installation

* Test mattersim in workflow

* Update docs

* Fix mattersim skip

* Test mattersim on mac

---------

Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>
  • Loading branch information
alinelena and ElliottKasoar authored Mar 3, 2025
1 parent cf74c42 commit adba856
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 1 deletion.
11 changes: 11 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ jobs:
PYTEST_ADDOPTS: "--durations=0"
run: uv run pytest --cov janus_core --cov-append .

- name: Install updated e3nn dependencies
run: |
uv sync --extra mattersim
uv pip install --reinstall pynvml
- name: Run test suite
env:
# show timings of tests
PYTEST_ADDOPTS: "--durations=0"
run: uv run pytest tests/test_{mlip_calculators,single_point}.py

- name: Install dgl dependencies
run: |
uv sync --extra mace --extra m3gnet --extra alignn
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ jobs:
PYTEST_ADDOPTS: "--durations=0"
run: uv run pytest

- name: Install updated e3nn dependencies
run: |
uv sync --extra mattersim
uv pip install --reinstall pynvml
- name: Run test suite
env:
# show timings of tests
PYTEST_ADDOPTS: "--durations=0"
run: uv run pytest tests/test_{mlip_calculators,single_point}.py

- name: Install dgl dependencies
run: |
uv sync --extra mace --extra m3gnet --extra alignn
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Current and planned features include:
- NequIP
- DPA3
- Orb
- MatterSim
- [x] Single point calculations
- [x] Geometry optimisation
- [x] Molecular Dynamics
Expand Down
1 change: 1 addition & 0 deletions docs/source/getting_started/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,6 @@ Currently supported extras are:
- ``nequip``: `NequIP <https://github.com/mir-group/nequip>`_
- ``dpa3``: `DPA3 <https://github.com/deepmodeling/deepmd-kit/tree/dpa3-alpha>`_
- ``orb``: `Orb <https://github.com/orbital-materials/orb-models>`_
- ``mattersim``: `MatterSim <https://github.com/microsoft/mattersim>`_

``extras`` are also listed in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_ under ``[project.optional-dependencies]``.
1 change: 1 addition & 0 deletions janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class CorrelationKwargs(TypedDict, total=True):
"nequip",
"dpa3",
"orb",
"mattersim",
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]
Expand Down
21 changes: 21 additions & 0 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,27 @@ def choose_calculator(

calculator = ORBCalculator(model=model, device=device, **kwargs)

elif arch == "mattersim":
from mattersim import __version__
from mattersim.forcefield import MatterSimCalculator
from torch.nn import Module

# Default model
model_path = model_path if model_path else "mattersim-v1.0.0-5M"

if isinstance(model_path, Module):
potential = model_path
model_path = "loaded_Module"
else:
potential = None

if isinstance(model_path, Path):
model_path = str(model_path)

calculator = MatterSimCalculator(
potential=potential, load_path=model_path, device=device, **kwargs
)

else:
raise ValueError(
f"Unrecognized {arch=}. Suported architectures "
Expand Down
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ all = [
"janus-core[sevennet]",
]

# MLIP with updated e3nn
mattersim = [
"mattersim == 1.1.1",
]

# MLIPs with dgl dependency
alignn = [
"alignn == 2024.5.27",
Expand Down Expand Up @@ -216,6 +221,14 @@ conflicts = [
{ extra = "all" },
{ extra = "m3gnet" },
],
[
{ extra = "mattersim" },
{ extra = "mace" },
],
[
{ extra = "mattersim" },
{ extra = "all" },
],
]

[tool.uv.sources]
Expand Down
10 changes: 9 additions & 1 deletion tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
("chgnet", "cpu", {"model": CHGNET_MODEL}),
("dpa3", "cpu", {"model_path": DPA3_PATH}),
("dpa3", "cpu", {"model": DPA3_PATH}),
("mattersim", "cpu", {}),
("mattersim", "cpu", {"model_path": "mattersim-v1.0.0-1m"}),
("nequip", "cpu", {"model_path": NEQUIP_PATH}),
("nequip", "cpu", {"model": NEQUIP_PATH}),
("orb", "cpu", {}),
Expand Down Expand Up @@ -121,6 +123,7 @@ def test_invalid_arch():
("mace_mp", "/invalid/path"),
("chgnet", "/invalid/path"),
("dpa3", "/invalid/path"),
("mattersim", "/invalid/path"),
("nequip", "/invalid/path"),
("orb", "/invalid/path"),
("sevenn", "/invalid/path"),
Expand All @@ -131,7 +134,7 @@ def test_invalid_arch():
def test_invalid_model_path(arch, model_path):
"""Test error raised for invalid model_path."""
skip_extras(arch)
with pytest.raises((ValueError, RuntimeError, KeyError)):
with pytest.raises((ValueError, RuntimeError, KeyError, AssertionError)):
choose_calculator(arch=arch, model_path=model_path)


Expand All @@ -145,6 +148,11 @@ def test_invalid_model_path(arch, model_path):
{"arch": "chgnet", "model_path": CHGNET_PATH, "path": CHGNET_PATH},
{"arch": "dpa3", "model_path": DPA3_PATH, "model": DPA3_PATH},
{"arch": "dpa3", "model_path": DPA3_PATH, "path": DPA3_PATH},
{
"arch": "mattersim",
"model_path": "mattersim-v1.0.0-1m",
"path": "mattersim-v1.0.0-1m",
},
{"arch": "nequip", "model_path": NEQUIP_PATH, "model": NEQUIP_PATH},
{"arch": "nequip", "model_path": NEQUIP_PATH, "path": NEQUIP_PATH},
{"arch": "orb", "model_path": ORB_MODEL, "model": ORB_MODEL},
Expand Down
30 changes: 30 additions & 0 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
)
def test_potential_energy(struct, expected, properties, prop_key, calc_kwargs, idx):
"""Test single point energy using MACE calculators."""
skip_extras("mace")

calc_kwargs["model"] = MACE_PATH

single_point = SinglePoint(
struct=DATA_PATH / struct,
arch="mace",
Expand Down Expand Up @@ -68,6 +71,7 @@ def test_potential_energy(struct, expected, properties, prop_key, calc_kwargs, i
[
("chgnet", "cpu", -29.331436157226562, "NaCl.cif", {}),
("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}),
("mattersim", "cpu", -27.06208038330078, "NaCl.cif", {}),
(
"nequip",
"cpu",
Expand Down Expand Up @@ -126,6 +130,8 @@ def test_extras(arch, device, expected_energy, struct, kwargs):

def test_single_point_none():
"""Test single point stress using MACE calculator."""
skip_extras("mace")

single_point = SinglePoint(
struct=DATA_PATH / "NaCl.cif",
arch="mace",
Expand All @@ -139,6 +145,8 @@ def test_single_point_none():

def test_single_point_clean():
"""Test single point stress using MACE calculator."""
skip_extras("mace")

single_point = SinglePoint(
struct=DATA_PATH / "H2O.cif",
arch="mace",
Expand All @@ -153,6 +161,8 @@ def test_single_point_clean():

def test_single_point_traj():
"""Test single point stress using MACE calculator."""
skip_extras("mace")

single_point = SinglePoint(
struct=DATA_PATH / "benzene-traj.xyz",
arch="mace",
Expand All @@ -174,6 +184,8 @@ def test_single_point_traj():

def test_single_point_write():
"""Test writing singlepoint results."""
skip_extras("mace")

data_path = DATA_PATH / "NaCl.cif"
results_path = Path("./NaCl-results.extxyz").absolute()
assert not results_path.exists()
Expand Down Expand Up @@ -208,6 +220,8 @@ def test_single_point_write():

def test_single_point_write_kwargs(tmp_path):
"""Test passing write_kwargs to singlepoint results."""
skip_extras("mace")

data_path = DATA_PATH / "NaCl.cif"
results_path = tmp_path / "NaCl.extxyz"

Expand All @@ -228,6 +242,8 @@ def test_single_point_write_kwargs(tmp_path):

def test_single_point_molecule(tmp_path):
"""Test singlepoint results for isolated molecule."""
skip_extras("mace")

data_path = DATA_PATH / "H2O.cif"
results_path = tmp_path / "H2O.extxyz"
single_point = SinglePoint(
Expand Down Expand Up @@ -255,6 +271,8 @@ def test_single_point_molecule(tmp_path):

def test_invalid_prop():
"""Test invalid property request."""
skip_extras("mace")

with pytest.raises(NotImplementedError):
SinglePoint(
struct=DATA_PATH / "H2O.cif",
Expand All @@ -266,6 +284,8 @@ def test_invalid_prop():

def test_atoms():
"""Test passing ASE Atoms structure."""
skip_extras("mace")

struct = read(DATA_PATH / "NaCl.cif")
single_point = SinglePoint(
struct=struct,
Expand All @@ -287,7 +307,10 @@ def test_no_atoms_or_path():

def test_invalidate_calc():
"""Test setting invalidate_calc via write_kwargs."""
skip_extras("mace")

struct = DATA_PATH / "NaCl.cif"

single_point = SinglePoint(
struct=struct,
arch="mace",
Expand All @@ -305,6 +328,8 @@ def test_invalidate_calc():

def test_logging(tmp_path):
"""Test attaching logger to SinglePoint and emissions are saved to info."""
skip_extras("mace")

log_file = tmp_path / "sp.log"

single_point = SinglePoint(
Expand All @@ -326,6 +351,8 @@ def test_logging(tmp_path):

def test_hessian():
"""Test Hessian."""
skip_extras("mace")

sp = SinglePoint(
calc_kwargs={"model": MACE_PATH},
struct=DATA_PATH / "NaCl.cif",
Expand All @@ -340,6 +367,8 @@ def test_hessian():

def test_hessian_traj():
"""Test calculating Hessian for trajectory."""
skip_extras("mace")

sp = SinglePoint(
calc_kwargs={"model": MACE_PATH},
struct=DATA_PATH / "benzene-traj.xyz",
Expand All @@ -359,6 +388,7 @@ def test_hessian_traj():
def test_hessian_not_implemented(struct):
"""Test unimplemented Hessian."""
skip_extras("chgnet")

with pytest.raises(NotImplementedError):
SinglePoint(
struct=DATA_PATH / struct,
Expand Down
2 changes: 2 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def skip_extras(arch: str):
pytest.importorskip("deepmd")
case "mace" | "mace_mp" | "mace_off":
pytest.importorskip("mace")
case "mattersim":
pytest.importorskip("mattersim")
case "nequip":
pytest.importorskip("nequip")
case "orb":
Expand Down

0 comments on commit adba856

Please sign in to comment.