diff --git a/README.md b/README.md index 7608b4d..4246738 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,13 @@ atoms.calc = NequixCalculator("nequix-mp-1", backend="torch") These are typically comparable in speed with kernels. +Analytical Hessians can be calculated with (currently only supported for JAX backend): + +```python +calc = NequixCalculator("nequix-mp-1", backend="jax") +calc.get_hessian(atoms) # np array of shape (n, n, 3, 3) +``` + #### NequixCalculator Arguments diff --git a/nequix/calculator.py b/nequix/calculator.py index 2336f0d..302648f 100644 --- a/nequix/calculator.py +++ b/nequix/calculator.py @@ -13,10 +13,11 @@ dict_to_pytorch_geometric, preprocess_graph, ) -from nequix.torch_impl.model import NequixTorch from nequix.model import Nequix from nequix.model import load_model as load_model_jax from nequix.model import save_model as save_model_jax +from nequix.pft.hessian import hessian_linearized +from nequix.torch_impl.model import NequixTorch def from_pretrained( @@ -128,33 +129,33 @@ def __init__( self._capacity_multiplier = capacity_multiplier self.backend = backend + def _pad_graph_jax(self, graph, numbers_changed=False): + # maintain edge capacity with _capacity_multiplier over edges, + # recalculate if numbers (system) changes, or if the capacity is exceeded + if self._capacity is None or numbers_changed or graph.n_edge[0] > self._capacity: + raw = int(np.ceil(graph.n_edge[0] * self._capacity_multiplier)) + # round up edges to the nearest multiple of 64 + # NB: this avoids excessive recompilation in high-throughput + # workflows (e.g. material relaxtions) but this number may need + # to be tuned depending on the system sizes + self._capacity = ((raw + 63) // 64) * 64 + + # round up nodes to the nearest multiple of 8 + # NB: this avoids excessive recompilation in high-throughput + # workflows (e.g. material relaxtions) but this number may need to + # be tuned depending on the system sizes + n_node = ((graph.n_node[0] + 8) // 8) * 8 + + # pad the graph + graph = jraph.pad_with_graphs(graph, n_node=n_node, n_edge=self._capacity, n_graph=2) + return graph + def calculate(self, atoms=None, properties=None, system_changes=all_changes): Calculator.calculate(self, atoms) processed_graph = preprocess_graph(atoms, self.atom_indices, self.cutoff, False) if self.backend == "jax": graph = dict_to_graphstuple(processed_graph) - # maintain edge capacity with _capacity_multiplier over edges, - # recalculate if numbers (system) changes, or if the capacity is exceeded - if ( - self._capacity is None - or ("numbers" in system_changes) - or graph.n_edge[0] > self._capacity - ): - raw = int(np.ceil(graph.n_edge[0] * self._capacity_multiplier)) - # round up edges to the nearest multiple of 64 - # NB: this avoids excessive recompilation in high-throughput - # workflows (e.g. material relaxtions) but this number may need - # to be tuned depending on the system sizes - self._capacity = ((raw + 63) // 64) * 64 - - # round up nodes to the nearest multiple of 8 - # NB: this avoids excessive recompilation in high-throughput - # workflows (e.g. material relaxtions) but this number may need to - # be tuned depending on the system sizes - n_node = ((graph.n_node[0] + 8) // 8) * 8 - - # pad the graph - graph = jraph.pad_with_graphs(graph, n_node=n_node, n_edge=self._capacity, n_graph=2) + graph = self._pad_graph_jax(graph, "numbers" in system_changes) energy, forces, stress = eqx.filter_jit(self.model)(graph) forces = forces[: len(atoms)] @@ -216,3 +217,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): self.results["stress"] = ( full_3x3_to_voigt_6_stress(np.array(stress[0])) if stress is not None else None ) + + def get_hessian(self, atoms=None): + assert self.backend == "jax", "Hessian calculation currently only supported for JAX backend" + if atoms is None and self.atoms is None: + raise ValueError("atoms not set") + if atoms is None: + atoms = self.atoms + n_atoms = len(atoms) + processed_graph = preprocess_graph(atoms, self.atom_indices, self.cutoff, False) + graph = dict_to_graphstuple(processed_graph) + graph = self._pad_graph_jax(graph, True) + hessian = eqx.filter_jit(hessian_linearized)(self.model, graph) + return np.array(hessian[:n_atoms, :n_atoms, :, :], copy=True) diff --git a/nequix/pft/hessian.py b/nequix/pft/hessian.py index a0a426c..c61d3f1 100644 --- a/nequix/pft/hessian.py +++ b/nequix/pft/hessian.py @@ -3,7 +3,6 @@ import equinox as eqx -@eqx.filter_jit def hessian_linearized(model, graph, batch_size=None): cell_per_edge = jnp.repeat( graph.globals["cell"], @@ -22,7 +21,6 @@ def total_energy_fn(positions): pos = graph.nodes["positions"] _, hvp = jax.linearize(jax.grad(total_energy_fn), pos) - hvp = jax.jit(hvp) basis = jnp.eye(pos.shape[0] * pos.shape[1]).reshape(-1, *pos.shape) return ( jax.lax.map(hvp, basis, batch_size=batch_size) diff --git a/pyproject.toml b/pyproject.toml index a9aa9a0..56b68ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ jax-md = { git = "https://github.com/jax-md/jax-md.git" } [project.optional-dependencies] torch = [ "e3nn>=0.5.8", - "openequivariance>=0.5.4", + "openequivariance>=0.6.4", "torch>=2.7.0", "torch-geometric>=2.6.1", "setuptools", @@ -75,7 +75,7 @@ torch-sim = [ oeq = [ # You need to install openequivariance_extjax separately # uv pip install openequivariance_extjax --no-build-isolation - "openequivariance[jax]>=0.5.4", + "openequivariance[jax]>=0.6.4", ] pft = [ "phonopy>=2.43.1", diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 0ca2dbf..8a6955d 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -158,3 +158,13 @@ def test_calculator_nequix_mp_1_without_cell(backend, kernel): assert np.isfinite(energy) assert forces.shape == (len(atoms), 3) assert np.all(np.isfinite(forces)) + + +@pytest.mark.parametrize("kernel", [False, pytest.param(True, marks=skip_no_oeq)]) +def test_calculator_hessian(kernel): + atoms = si() + calc = NequixCalculator(model_name="nequix-mp-1", backend="jax", use_kernel=kernel) + hessian = calc.get_hessian(atoms) + print(hessian) + assert hessian.shape == (len(atoms), len(atoms), 3, 3) + assert np.all(np.isfinite(hessian)) diff --git a/uv.lock b/uv.lock index 4c67e57..778e1a8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'linux'", @@ -2296,8 +2296,8 @@ requires-dist = [ { name = "jax", extras = ["cuda12"], marker = "sys_platform == 'linux'", specifier = ">=0.4.34" }, { name = "jraph", specifier = ">=0.0.6.dev0" }, { name = "matscipy", specifier = ">=1.1.1" }, - { name = "openequivariance", marker = "extra == 'torch'", specifier = ">=0.5.4" }, - { name = "openequivariance", extras = ["jax"], marker = "extra == 'oeq'", specifier = ">=0.5.4" }, + { name = "openequivariance", marker = "extra == 'torch'", specifier = ">=0.6.4" }, + { name = "openequivariance", extras = ["jax"], marker = "extra == 'oeq'", specifier = ">=0.6.4" }, { name = "optax", specifier = ">=0.2.5" }, { name = "phonopy", marker = "extra == 'pft'", specifier = ">=2.43.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, @@ -2666,7 +2666,7 @@ wheels = [ [[package]] name = "openequivariance" -version = "0.5.4" +version = "0.6.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2" }, @@ -2674,7 +2674,7 @@ dependencies = [ { name = "numpy" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/29/1e76c9e8bd02ca4f8c7f67a31bdb09fe1960a99d9fc63e3bfedc9725f7ea/openequivariance-0.5.4.tar.gz", hash = "sha256:4454524d2acee2c4a255ba492a1fb7160be8e4631c955401ad22d86c63812110", size = 105102, upload-time = "2026-02-10T03:54:25.485Z" } +sdist = { url = "https://files.pythonhosted.org/packages/be/d7/a3166804835d6f37bdc32055747bb16683473f36d72edaba7f6caf0c1600/openequivariance-0.6.4.tar.gz", hash = "sha256:225a871ce375bca0ea9d96c54c77a54e6b1a8073b10be20378dc24a311400804", size = 100702, upload-time = "2026-03-06T04:20:05.002Z" } [package.optional-dependencies] jax = [ @@ -2929,7 +2929,7 @@ name = "pexpect" version = "4.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ptyprocess" }, + { name = "ptyprocess", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } wheels = [