Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ atoms.calc = NequixCalculator("nequix-mp-1", backend="torch")

These are typically comparable in speed with kernels.

Analytical Hessians can be calculated with (currently only supported for JAX backend):

```python
calc = NequixCalculator("nequix-mp-1", backend="jax")
calc.get_hessian(atoms) # np array of shape (n, n, 3, 3)
```

#### NequixCalculator

Arguments
Expand Down
60 changes: 37 additions & 23 deletions nequix/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
dict_to_pytorch_geometric,
preprocess_graph,
)
from nequix.torch_impl.model import NequixTorch
from nequix.model import Nequix
from nequix.model import load_model as load_model_jax
from nequix.model import save_model as save_model_jax
from nequix.pft.hessian import hessian_linearized
from nequix.torch_impl.model import NequixTorch


def from_pretrained(
Expand Down Expand Up @@ -128,33 +129,33 @@ def __init__(
self._capacity_multiplier = capacity_multiplier
self.backend = backend

def _pad_graph_jax(self, graph, numbers_changed=False):
# maintain edge capacity with _capacity_multiplier over edges,
# recalculate if numbers (system) changes, or if the capacity is exceeded
if self._capacity is None or numbers_changed or graph.n_edge[0] > self._capacity:
raw = int(np.ceil(graph.n_edge[0] * self._capacity_multiplier))
# round up edges to the nearest multiple of 64
# NB: this avoids excessive recompilation in high-throughput
# workflows (e.g. material relaxtions) but this number may need
# to be tuned depending on the system sizes
self._capacity = ((raw + 63) // 64) * 64

# round up nodes to the nearest multiple of 8
# NB: this avoids excessive recompilation in high-throughput
# workflows (e.g. material relaxtions) but this number may need to
# be tuned depending on the system sizes
n_node = ((graph.n_node[0] + 8) // 8) * 8

# pad the graph
graph = jraph.pad_with_graphs(graph, n_node=n_node, n_edge=self._capacity, n_graph=2)
return graph

def calculate(self, atoms=None, properties=None, system_changes=all_changes):
Calculator.calculate(self, atoms)
processed_graph = preprocess_graph(atoms, self.atom_indices, self.cutoff, False)
if self.backend == "jax":
graph = dict_to_graphstuple(processed_graph)
# maintain edge capacity with _capacity_multiplier over edges,
# recalculate if numbers (system) changes, or if the capacity is exceeded
if (
self._capacity is None
or ("numbers" in system_changes)
or graph.n_edge[0] > self._capacity
):
raw = int(np.ceil(graph.n_edge[0] * self._capacity_multiplier))
# round up edges to the nearest multiple of 64
# NB: this avoids excessive recompilation in high-throughput
# workflows (e.g. material relaxtions) but this number may need
# to be tuned depending on the system sizes
self._capacity = ((raw + 63) // 64) * 64

# round up nodes to the nearest multiple of 8
# NB: this avoids excessive recompilation in high-throughput
# workflows (e.g. material relaxtions) but this number may need to
# be tuned depending on the system sizes
n_node = ((graph.n_node[0] + 8) // 8) * 8

# pad the graph
graph = jraph.pad_with_graphs(graph, n_node=n_node, n_edge=self._capacity, n_graph=2)
graph = self._pad_graph_jax(graph, "numbers" in system_changes)
energy, forces, stress = eqx.filter_jit(self.model)(graph)
forces = forces[: len(atoms)]

Expand Down Expand Up @@ -216,3 +217,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
self.results["stress"] = (
full_3x3_to_voigt_6_stress(np.array(stress[0])) if stress is not None else None
)

def get_hessian(self, atoms=None):
assert self.backend == "jax", "Hessian calculation currently only supported for JAX backend"
if atoms is None and self.atoms is None:
raise ValueError("atoms not set")
if atoms is None:
atoms = self.atoms
n_atoms = len(atoms)
processed_graph = preprocess_graph(atoms, self.atom_indices, self.cutoff, False)
graph = dict_to_graphstuple(processed_graph)
graph = self._pad_graph_jax(graph, True)
hessian = eqx.filter_jit(hessian_linearized)(self.model, graph)
return np.array(hessian[:n_atoms, :n_atoms, :, :], copy=True)
2 changes: 0 additions & 2 deletions nequix/pft/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import equinox as eqx


@eqx.filter_jit
def hessian_linearized(model, graph, batch_size=None):
cell_per_edge = jnp.repeat(
graph.globals["cell"],
Expand All @@ -22,7 +21,6 @@ def total_energy_fn(positions):

pos = graph.nodes["positions"]
_, hvp = jax.linearize(jax.grad(total_energy_fn), pos)
hvp = jax.jit(hvp)
basis = jnp.eye(pos.shape[0] * pos.shape[1]).reshape(-1, *pos.shape)
return (
jax.lax.map(hvp, basis, batch_size=batch_size)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jax-md = { git = "https://github.com/jax-md/jax-md.git" }
[project.optional-dependencies]
torch = [
"e3nn>=0.5.8",
"openequivariance>=0.5.4",
"openequivariance>=0.6.4",
"torch>=2.7.0",
"torch-geometric>=2.6.1",
"setuptools",
Expand All @@ -75,7 +75,7 @@ torch-sim = [
oeq = [
# You need to install openequivariance_extjax separately
# uv pip install openequivariance_extjax --no-build-isolation
"openequivariance[jax]>=0.5.4",
"openequivariance[jax]>=0.6.4",
]
pft = [
"phonopy>=2.43.1",
Expand Down
10 changes: 10 additions & 0 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,13 @@ def test_calculator_nequix_mp_1_without_cell(backend, kernel):
assert np.isfinite(energy)
assert forces.shape == (len(atoms), 3)
assert np.all(np.isfinite(forces))


@pytest.mark.parametrize("kernel", [False, pytest.param(True, marks=skip_no_oeq)])
def test_calculator_hessian(kernel):
atoms = si()
calc = NequixCalculator(model_name="nequix-mp-1", backend="jax", use_kernel=kernel)
hessian = calc.get_hessian(atoms)
print(hessian)
assert hessian.shape == (len(atoms), len(atoms), 3, 3)
assert np.all(np.isfinite(hessian))
12 changes: 6 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.