sketch of hessian and vibrational analysis implementation#532

Open
M-R-Schaefer wants to merge 18 commits into main from hessian_pred

Conversation

@M-R-Schaefer
Contributor

@M-R-Schaefer M-R-Schaefer commented Feb 23, 2026

Summary by CodeRabbit

  • New Features

    • Analytical Hessian support: requestable at runtime, exposed via a public getter, returned in ASE-compatible shape.
    • Runtime overrides to enable/disable stress, Hessian, and uncertainty (force variance) during inference and MD runs.
  • Documentation

    • New vibrational analysis tutorial notebook and usage guide demonstrating analytical vs numerical Hessians and configuration examples.
  • Tests

    • New unit and integration tests validating Hessian computation and ASE Vibrations compatibility.

@coderabbitai
Contributor

coderabbitai bot commented Feb 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Adds optional analytical Hessian support end-to-end: config/model flags to enable Hessian, model changes to compute/export Hessians, ASECalculator integration and accessor, neighbor-function and builder plumbing, tests, and a tutorial notebook demonstrating analytical vs numerical Hessians.
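
The 3N×3N reshaping mentioned in the walkthrough can be illustrated with a minimal NumPy sketch. It assumes the raw Hessian has the shape `(N, 3, N, 3)`, which is what `jax.hessian` returns for a scalar function of `(N, 3)` positions. The toy energy and the helper name `flatten_hessian` are hypothetical and are not taken from apax.

```python
import numpy as np


def flatten_hessian(hessian_4d):
    """Reshape an (N, 3, N, 3) Hessian into the flat (3N, 3N) layout."""
    n_atoms = hessian_4d.shape[0]
    return hessian_4d.reshape(3 * n_atoms, 3 * n_atoms)


# Toy system (an assumption for illustration):
#   E = 0.5 * k * sum_{i<j} |r_i - r_j|^2
# Its Hessian is known in closed form: k * (N - 1) * I on the diagonal
# blocks and -k * I on every off-diagonal block.
n_atoms, k = 3, 2.0
hessian_4d = np.zeros((n_atoms, 3, n_atoms, 3))
for i in range(n_atoms):
    for j in range(n_atoms):
        scale = k * (n_atoms - 1) if i == j else -k
        hessian_4d[i, :, j, :] = scale * np.eye(3)

hessian_2d = flatten_hessian(hessian_4d)

# Element [i, a, j, b] of the 4D tensor lands at row 3*i + a, column 3*j + b.
i, a, j, b = 1, 2, 0, 2
print(hessian_2d.shape, hessian_2d[3 * i + a, 3 * j + b] == hessian_4d[i, a, j, b])
```

The row-major reshape keeps the three Cartesian components of each atom contiguous, which matches the atom-major ordering of a flattened `(N, 3)` position array.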

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| Model & config | apax/nn/models.py, apax/config/model_config.py | Added calc_hessian: bool = False to model config and model classes; models can compute and return "hessian" (via jax.hessian) when enabled. |
| Builder & simulation wiring | apax/nn/builder.py, apax/md/simulate.py | Extended build_energy_derivative_model(...) with calc_stress, calc_hessian, and force_variance; threaded these flags from MD config into model construction. |
| ASE calculator & neighbor fns | apax/md/ase_calc.py | Added calc_hessian/calc_stress/force_variance plumbing to neighbor-fn builders (build_energy_neighbor_fns, build_hessian_neighbor_fns), introduced _wrap_model, _initialize_hessian, hessian_step attribute, get_hessian(atoms), and conditional Hessian computation on JAX/non-JAX paths with 3N×3N reshaping. |
| Model builder classes | apax/nn/models.py, apax/nn/builder.py | EnergyDerivativeModel and ShallowEnsembleModel now expose calc_hessian and compute Hessians when requested; shallow ensemble passes force_variance. |
| Tests | tests/unit_tests/model/test_hessian.py, tests/integration_tests/md/test_ase_hessian.py, tests/unit_tests/md/test_ase_calc.py | Added unit/integration tests validating analytical Hessian vs numerical JAX/ASE results; updated test helpers to accept calc_hessian and assert implemented properties. |
| Docs & examples | examples/05_Vibrational_Analysis.ipynb, docs/source/_tutorials/05_Vibrational_Analysis.nblink, docs/source/_tutorials/index.rst, docs/source/usage/index.rst, docs/source/usage/property_prediction.rst | Added tutorial notebook and docs describing dynamic enabling of calc_stress/calc_hessian/force_variance, ASECalculator get_hessian, and vibrational-analysis workflow. |
| Misc tests | tests/unit_tests/md/test_create_openmm_simulation.py | Replaced exact equality checks with np.allclose for floating-point comparisons. |

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant ASECalc as ASECalculator
    participant Model as EnergyDerivativeModel / Ensemble
    participant JAX as JAX Runtime
    participant ASE as ASE Vibrations

    User->>ASECalc: calculate(atoms, properties=["hessian"])
    ASECalc->>ASECalc: ensure model wrapped (_wrap_model)
    ASECalc->>Model: call hessian_step(positions)
    Model->>JAX: evaluate jax.hessian(energy_fn)(positions)
    JAX-->>Model: Hessian tensor
    Model-->>ASECalc: Hessian (raw)
    ASECalc->>ASECalc: reshape to 3N×3N, store results["hessian"]
    ASECalc-->>User: return results including "hessian"
    User->>ASECalc: get_hessian(atoms)
    ASECalc-->>User: stored 3N×3N Hessian
    alt compare with ASE Vibrations
        User->>ASE: run Vibrations -> numerical Hessian
        ASE-->>User: numerical Hessian
        User->>User: compare analytical vs numerical
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 I hopped through tensors, second derivatives found,
JAX traced the energy, Hessians all around.
Atoms hum and vibrate, matrices in tune,
Analytic meets numeric beneath the moon.
Hop, compute, compare — a rabbit's boon.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 24.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main changes: adding Hessian computation and vibrational analysis capabilities to the codebase, which represents the core of this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch hessian_pred

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
apax/nn/models.py (1)

145-145: Update EnergyDerivativeModel docstring to mention Hessian support.

Line 145 introduces calc_hessian; the class docstring still only mentions forces/stress. Consider updating it for accuracy.

📝 Suggested docstring tweak
-    """Transforms an EnergyModel into one that also predicts derivatives the total energy.
-    Can calculate forces and stress tensors.
-    """
+    """Transforms an EnergyModel into one that also predicts derivatives of the total energy.
+    Can calculate forces, stress tensors, and Hessians.
+    """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@apax/nn/models.py` at line 145, The EnergyDerivativeModel docstring is out of
date: the class now has a calc_hessian boolean (added as calc_hessian: bool =
False) but the docstring only mentions forces/stress; update the
EnergyDerivativeModel class docstring to explicitly state that the model can
optionally compute Hessians when calc_hessian is True, describe the
returned/available Hessian data alongside forces and stress, and note any
performance/memory implications or shape/units conventions used for the Hessian
to keep the doc accurate.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@apax/md/ase_calc.py`:
- Around line 269-282: The Hessian pipeline currently builds hessian_model via
build_hessian_neighbor_fns and sets self.hessian_step directly, bypassing any
transformations or stress processing applied to self.model; update the code so
the hessian_model is passed through the same transformation/stress pipeline (the
same transforms or ProcessStress wrapper used when constructing self.model)
before creating self.hessian_step (or clearly document that Hessians are
computed on the raw model). Ensure you reference and reuse the existing
transformation chain applied to self.model (or invoke the same ProcessStress
wrapper) so get_hessian(), self.hessian_step, and the calculator’s reported
energies/forces remain consistent.

In `@examples/05_Vibrational_Analysis.ipynb`:
- Around line 204-213: The formatted print for frequencies appends a stray
backslash in the f-string; in the loop that prints first 10 frequencies (using
freqs_ana, freqs_num and diff) remove the trailing "\\" from the final print
call (the line printing f"{first:>15} {second:>15} {diff:.4f}\\") so it prints a
normal newline instead of a literal backslash; keep the rest of the formatting
unchanged.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7926304 and 29c9b8e.

📒 Files selected for processing (7)
  • apax/md/ase_calc.py
  • apax/nn/models.py
  • docs/source/_tutorials/05_Vibrational_Analysis.nblink
  • docs/source/_tutorials/index.rst
  • examples/05_Vibrational_Analysis.ipynb
  • tests/integration_tests/md/test_ase_hessian.py
  • tests/unit_tests/model/test_hessian.py

Comment on lines +204 to +213
"print(\"\\nFirst 10 Frequencies (cm^-1):\")\n",
"print(f\"{'Analytical':>15} {'Numerical':>15} {'Diff':>10}\")\n",
"for i in range(10):\n",
" f_a = freqs_ana[i].real if freqs_ana[i].is_real else freqs_ana[i].imag * 1j\n",
" f_n = freqs_num[i].real if freqs_num[i].is_real else freqs_num[i].imag * 1j\n",
" diff = np.abs(f_a - f_n)\n",
" first = str(f_a)[:14]\n",
" second = str(f_n)[:14]\n",
" print(f\"{first:>15} {second:>15} {diff:.4f}\\\\\")"
]
Contributor


⚠️ Potential issue | 🟡 Minor

Remove stray backslash in frequency printout.

Line 212 prints an extra \ at the end of each line.
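
The effect of the trailing `\\` can be reproduced in isolation. The sketch below uses made-up placeholder values rather than the notebook's frequencies. Inside a Python f-string, `\\` is an escaped backslash, so it ends up as one literal `\` character in the output.

```python
# Placeholder values standing in for the notebook's formatted frequencies.
first, second, diff = "12.30", "12.34", 0.04

# Original form: the escaped backslash becomes a literal "\" at the end.
with_backslash = f"{first:>15} {second:>15} {diff:.4f}\\"

# Fixed form: no trailing backslash, print() adds the newline itself.
fixed = f"{first:>15} {second:>15} {diff:.4f}"

print(repr(with_backslash))
print(repr(fixed))
```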

✂️ Suggested fix
-    print(f"{first:>15} {second:>15} {diff:.4f}\\")
+    print(f"{first:>15} {second:>15} {diff:.4f}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"print(\"\\nFirst 10 Frequencies (cm^-1):\")\n",
"print(f\"{'Analytical':>15} {'Numerical':>15} {'Diff':>10}\")\n",
"for i in range(10):\n",
" f_a = freqs_ana[i].real if freqs_ana[i].is_real else freqs_ana[i].imag * 1j\n",
" f_n = freqs_num[i].real if freqs_num[i].is_real else freqs_num[i].imag * 1j\n",
" diff = np.abs(f_a - f_n)\n",
" first = str(f_a)[:14]\n",
" second = str(f_n)[:14]\n",
" print(f\"{first:>15} {second:>15} {diff:.4f}\\\\\")"
]
"print(\"\\nFirst 10 Frequencies (cm^-1):\")\n",
"print(f\"{'Analytical':>15} {'Numerical':>15} {'Diff':>10}\")\n",
"for i in range(10):\n",
" f_a = freqs_ana[i].real if freqs_ana[i].is_real else freqs_ana[i].imag * 1j\n",
" f_n = freqs_num[i].real if freqs_num[i].is_real else freqs_num[i].imag * 1j\n",
" diff = np.abs(f_a - f_n)\n",
" first = str(f_a)[:14]\n",
" second = str(f_n)[:14]\n",
" print(f\"{first:>15} {second:>15} {diff:.4f}\")"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/05_Vibrational_Analysis.ipynb` around lines 204 - 213, The formatted
print for frequencies appends a stray backslash in the f-string; in the loop
that prints first 10 frequencies (using freqs_ana, freqs_num and diff) remove
the trailing "\\" from the final print call (the line printing f"{first:>15}
{second:>15} {diff:.4f}\\") so it prints a normal newline instead of a literal
backslash; keep the rest of the formatting unchanged.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/integration_tests/md/test_ase_hessian.py (2)

75-75: Redundant Hessian computation.

analytical_hessian on line 75 duplicates the computation already done on line 67. Consider reusing the hessian variable or removing the duplicate call.

♻️ Proposed fix
     # 1. Test direct hessian calculation via calculator
     hessian = calc.get_hessian(atoms)
     assert hessian.shape == (9, 9)
     assert not np.allclose(hessian, 0.0)

     # 2. Test compatibility with Vibrations module
     vib_dir = get_tmp_path / "vib"
     vib = Vibrations(atoms, name=str(vib_dir), nfree=4)

-    analytical_hessian = calc.get_hessian(atoms)
+    analytical_hessian = hessian  # Reuse from direct test above

     # Numerical Hessian from Vibrations (finite difference of forces)
     vib.run()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration_tests/md/test_ase_hessian.py` at line 75, The test calls
calc.get_hessian(atoms) twice: once assigned to hessian and again to
analytical_hessian; remove the redundant second call by reusing the existing
hessian variable (or rename hessian to analytical_hessian consistently) so there
is only a single calc.get_hessian(atoms) invocation; update assertions that
reference analytical_hessian to use hessian (or vice versa) and delete the
duplicate line containing analytical_hessian = calc.get_hessian(atoms).

66-69: Consider adding a Hessian symmetry assertion.

The Hessian matrix should be symmetric by definition. Adding a symmetry check would strengthen the test and catch potential implementation bugs.

💡 Suggested addition
     hessian = calc.get_hessian(atoms)
     assert hessian.shape == (9, 9)
     assert not np.allclose(hessian, 0.0)
+    # Hessian should be symmetric
+    assert np.allclose(hessian, hessian.T), "Hessian matrix is not symmetric"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration_tests/md/test_ase_hessian.py` around lines 66 - 69, After
computing the Hessian with calc.get_hessian(atoms) (variable hessian), add an
assertion that the matrix is symmetric by comparing it to its transpose using
numpy allclose (e.g., assert np.allclose(hessian, hessian.T, atol=1e-8) or a
suitable tolerance) to catch asymmetry bugs in the implementation.
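
As a rough illustration of why the symmetry check uses a tolerance, the sketch below builds the closed-form Hessian of a two-atom harmonic spring (an assumed toy system, not the PR's test molecule) and compares a clean, a slightly noisy, and a genuinely broken matrix. The helper `max_asymmetry` is a hypothetical name.

```python
import numpy as np


def max_asymmetry(hessian):
    """Largest absolute difference between the matrix and its transpose."""
    return float(np.max(np.abs(hessian - hessian.T)))


# Two atoms joined by a spring, E = 0.5 * k * |r_1 - r_2|^2.
k = 1.5
eye = np.eye(3)
hessian = np.block([[k * eye, -k * eye], [-k * eye, k * eye]])

# Round-off-sized asymmetry: perturb only the strict upper triangle.
noisy = hessian + 1e-10 * np.triu(np.ones((6, 6)), 1)

# A real bug: a single element that does not match its mirror image.
broken = hessian.copy()
broken[0, 3] += 0.1

for name, matrix in [("clean", hessian), ("noisy", noisy), ("broken", broken)]:
    symmetric = np.allclose(matrix, matrix.T, atol=1e-8)
    print(f"{name:>6}: symmetric={symmetric}, max asymmetry={max_asymmetry(matrix):.1e}")
```

Reporting the maximum asymmetry alongside the boolean makes a failing assertion easier to diagnose than `np.allclose` alone.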
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/integration_tests/md/test_ase_hessian.py`:
- Line 19: Rename the misspelled variable model_confg_path to model_config_path
and update all usages (e.g., where it's referenced on line 21) to the corrected
name so the test uses model_config_path consistently; ensure you only change the
identifier and not surrounding logic or values.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 29c9b8e and ea17776.

📒 Files selected for processing (3)
  • examples/05_Vibrational_Analysis.ipynb
  • tests/integration_tests/md/test_ase_hessian.py
  • tests/unit_tests/model/test_hessian.py
✅ Files skipped from review due to trivial changes (1)
  • examples/05_Vibrational_Analysis.ipynb
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit_tests/model/test_hessian.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
apax/md/ase_calc.py (1)

254-276: ⚠️ Potential issue | 🟠 Major

Reset hessian_step on reinitialize to avoid stale Hessians.

When initialize() is called due to atomic number or cell changes, self.hessian_step can still point at the old Z/periodicity, yielding incorrect Hessians.

✅ Proposed fix
         model = self._wrap_model(model)

         self.model = model
+        self.hessian_step = None
         self.step = get_step_fn(
             model,
             jnp.asarray(atoms.numbers),
             bool(np.any(atoms.cell.array > 1e-6)),
             self.neigbor_from_jax,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@apax/md/ase_calc.py` around lines 254 - 276, The initialize method may leave
a stale Hessian reference in self.hessian_step when atoms (Z or cell) change;
update initialize to reset self.hessian_step to None so the stale step function
is invalidated on reinitialize and rebuilt on the next Hessian request; locate
the initialize method and add a line to clear self.hessian_step before
rebuilding model/step/neighbor_fn (symbols: initialize, self.hessian_step,
_wrap_model, get_step_fn).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@apax/config/model_config.py`:
- Around line 186-187: The BaseModelConfig docstring is missing documentation
for the new user-facing flag calc_hessian; update the BaseModelConfig class
docstring's Parameters section to add an entry for calc_hessian (similar style
to calc_stress) describing its type (bool), default (False), and
purpose/behavior; reference the calc_hessian and calc_stress attributes in the
text so users understand the difference and any interactions or side effects.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ea17776 and 671c7b9.

📒 Files selected for processing (3)
  • apax/config/model_config.py
  • apax/md/ase_calc.py
  • tests/unit_tests/md/test_ase_calc.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
apax/md/ase_calc.py (1)

319-365: ⚠️ Potential issue | 🔴 Critical

initialize does not reset self.hessian_step, leaving a stale closure when atoms change.

get_step_fn closes over Z = jnp.asarray(atoms.numbers). When initialize is called again (e.g., because "numbers" or "cell" changed in calculate), self.step is rebuilt with the new Z, but self.hessian_step keeps the old captured Z. Every subsequent Hessian request then silently computes the Hessian for the wrong species.

🐛 Proposed fix
 def initialize(self, atoms):
     box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
     self.r_max = self.model_config.model.basis.r_max
     self.neigbor_from_jax = neighbor_calculable_with_jax(box, self.r_max)
     model, neighbor_fn = build_energy_neighbor_fns(
         ...
     )
     model = self._wrap_model(model)
     self.model = model
     self.step = get_step_fn(
         model,
         jnp.asarray(atoms.numbers),
         bool(np.any(atoms.cell.array > 1e-6)),
         self.neigbor_from_jax,
     )
+    self.hessian_step = None  # invalidate stale Hessian closure
     self.neighbor_fn = neighbor_fn
     ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@apax/md/ase_calc.py` around lines 319 - 365, The initialize method rebuilds
self.step using get_step_fn but does not reset/rebuild self.hessian_step, so any
previously captured Z remains stale; fix by resetting or recomputing
self.hessian_step inside initialize (e.g. set self.hessian_step = None or call
get_step_fn to create a fresh hessian step using the new
jnp.asarray(atoms.numbers) and same neighbor/flags) so Hessian calculations use
the updated species vector; update the block that assigns self.step (inside
initialize) to also update self.hessian_step accordingly.
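
The stale-closure problem described in this comment can be sketched without JAX. In the toy below, `make_step` and `ToyCalculator` are hypothetical stand-ins for `get_step_fn` and the calculator. The only behavior borrowed from the review is that the step function captures the atomic numbers at construction time and that the Hessian step is built lazily when it is `None`.

```python
def make_step(numbers):
    """Return a step function that closes over the atomic numbers."""
    captured = tuple(numbers)

    def step():
        return captured

    return step


class ToyCalculator:
    def __init__(self, reset_on_initialize):
        self.reset_on_initialize = reset_on_initialize
        self.step = None
        self.hessian_step = None

    def initialize(self, numbers):
        self.step = make_step(numbers)
        if self.reset_on_initialize:
            # Invalidate the cached closure so it is rebuilt lazily.
            self.hessian_step = None

    def hessian_numbers(self, numbers):
        if self.hessian_step is None:
            self.hessian_step = make_step(numbers)
        return self.hessian_step()


water, co = [1, 1, 8], [6, 8]
results = {}
for reset in (False, True):
    calc = ToyCalculator(reset_on_initialize=reset)
    calc.initialize(water)
    calc.hessian_numbers(water)  # builds the Hessian closure for water
    calc.initialize(co)          # species changed
    results[reset] = calc.hessian_numbers(co)

print(results)
```

Without the reset, the second Hessian request still sees the species of the first system. With it, the closure is rebuilt for the new atoms.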
♻️ Duplicate comments (1)
tests/integration_tests/md/test_ase_hessian.py (1)

19-21: Typo in variable name: model_confg_path should be model_config_path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration_tests/md/test_ase_hessian.py` around lines 19 - 21, There
is a typo: the variable name model_confg_path should be model_config_path;
rename model_confg_path to model_config_path everywhere it’s defined and
referenced (e.g., the assignment using TEST_PATH and the subsequent with
open(model_confg_path.as_posix(), "r") block) so all usages (opening the file,
assertions, or helper calls) use the corrected identifier model_config_path.
🧹 Nitpick comments (1)
apax/md/ase_calc.py (1)

504-521: Double Hessian computation when calc_hessian=True.

When ASECalculator is created with calc_hessian=True, initialize builds the main model with calc_hessian=True (via build_energy_neighbor_fns), so self.step(...) already includes "hessian" in its output. The subsequent hessian_step call then recomputes and overwrites the same value — doubling the cost of the most expensive operation.

Consider skipping the hessian_step path when "hessian" is already present in results:

♻️ Proposed fix (JAX path; mirror for non-JAX)
 results, self.neighbors = self.step(positions, self.neighbors, box)
 ...
-if "hessian" in properties:
+if "hessian" in properties and "hessian" not in results:
     if self.hessian_step is None:
         self._initialize_hessian(atoms)
     hessian_results, _ = self.hessian_step(positions, self.neighbors, box)
     results["hessian"] = hessian_results["hessian"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@apax/md/ase_calc.py` around lines 504 - 521, The code currently always calls
self.hessian_step and overwrites results["hessian"] even when self.step already
produced a "hessian" (causing double work); update both branches (the JAX branch
where results = self.step(positions, ...) and the non-JAX branch) to skip
invoking self.hessian_step when "hessian" is already present in results: check
if "hessian" not in results before calling self._initialize_hessian(...) /
self.hessian_step(...), and only set results["hessian"] from hessian_results
when you actually ran hessian_step; keep use of self._initialize_hessian,
self.hessian_step, self.step, and the existing signature/arguments unchanged.
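
The cost argument behind this guard can be made concrete with a call counter. `CountingCalculator` below is a hypothetical stand-in, not the real `ASECalculator`. It only mimics the two paths named in the comment: a main step that may already include `"hessian"`, and a separate Hessian step.

```python
class CountingCalculator:
    """Toy calculator that counts how often the Hessian is evaluated."""

    def __init__(self, calc_hessian, guarded):
        self.calc_hessian = calc_hessian
        self.guarded = guarded
        self.hessian_calls = 0

    def _compute_hessian(self):
        self.hessian_calls += 1
        return [[1.0]]

    def step(self):
        results = {"energy": 0.0}
        if self.calc_hessian:
            results["hessian"] = self._compute_hessian()
        return results

    def hessian_step(self):
        return {"hessian": self._compute_hessian()}

    def calculate(self, properties):
        results = self.step()
        needs_hessian = "hessian" in properties
        if self.guarded:
            # Skip the extra pass when the main step already produced it.
            needs_hessian = needs_hessian and "hessian" not in results
        if needs_hessian:
            results["hessian"] = self.hessian_step()["hessian"]
        return results


counts = {}
for calc_hessian in (True, False):
    for guarded in (False, True):
        calc = CountingCalculator(calc_hessian, guarded)
        calc.calculate(["energy", "hessian"])
        counts[(calc_hessian, guarded)] = calc.hessian_calls

print(counts)
```

Only the unguarded `calc_hessian=True` combination evaluates the Hessian twice. The lazy path is unaffected by the guard.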
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@apax/md/simulate.py`:
- Around line 543-546: The new loop that sets force_variance by iterating
md_config.dynamics_checks can raise TypeError if dynamics_checks is None; update
the logic to check that md_config.dynamics_checks is truthy before iterating
(e.g., only run the for check in md_config.dynamics_checks when
md_config.dynamics_checks is not None/empty) so force_variance is determined
safely (preserve the existing initial assignment force_variance =
"forces_uncertainty" in md_config.properties and the loop that checks check.name
== "forces_uncertainty").
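
One possible shape for the safer loop is sketched below. The config objects are built with `SimpleNamespace` purely for illustration, and the function name `wants_force_variance` is hypothetical. Only the attribute names `properties`, `dynamics_checks`, and `name` come from the comment.

```python
from types import SimpleNamespace


def wants_force_variance(md_config):
    """Return True if force uncertainties are requested anywhere in the config."""
    force_variance = "forces_uncertainty" in md_config.properties
    # "or []" makes the loop a no-op when dynamics_checks is None or empty.
    for check in md_config.dynamics_checks or []:
        if check.name == "forces_uncertainty":
            force_variance = True
    return force_variance


no_checks = SimpleNamespace(properties=["energy"], dynamics_checks=None)
with_check = SimpleNamespace(
    properties=["energy"],
    dynamics_checks=[SimpleNamespace(name="forces_uncertainty")],
)
via_properties = SimpleNamespace(
    properties=["forces_uncertainty"], dynamics_checks=None
)

print(
    wants_force_variance(no_checks),
    wants_force_variance(with_check),
    wants_force_variance(via_properties),
)
```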

In `@docs/source/index.rst`:
- Line 6: Fix the stray period in the documentation sentence: locate the
sentence containing "`Apax` is a high-performance, extendable package for
training of and inference with atomistic neural networks., most prominently the
Gaussian Moment Neural Network model." and remove the errant '.' so it reads
"...atomistic neural networks, most prominently the Gaussian Moment Neural
Network model." ensuring the punctuation is correct and no extra characters
remain.

In `@docs/source/usage/property_prediction.rst`:
- Around line 34-39: The example incorrectly calls atoms.get_hessian()
(get_hessian is on the ASECalculator, not Atoms) and then passes a 2D (3N×3N)
matrix into VibrationsData.__init__ which expects a 4-D (N,3,N,3) Hessian; fix
by obtaining the Hessian from the calculator (call the ASECalculator instance's
get_hessian method) and then construct the vibrational data with
VibrationsData.from_2d(atoms, hessian) instead of VibrationsData(atoms,
hessian).


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 671c7b9 and a99e750.

📒 Files selected for processing (9)
  • apax/md/ase_calc.py
  • apax/md/simulate.py
  • apax/nn/builder.py
  • docs/source/index.rst
  • docs/source/usage/index.rst
  • docs/source/usage/property_prediction.rst
  • tests/integration_tests/md/test_ase_hessian.py
  • tests/unit_tests/md/test_ase_calc.py
  • tests/unit_tests/md/test_create_openmm_simulation.py
✅ Files skipped from review due to trivial changes (1)
  • docs/source/usage/index.rst

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
apax/md/ase_calc.py (2)

119-122: ⚠️ Potential issue | 🟠 Major

stress_ensemble is incorrectly overwritten with transposed forces_ensemble.

The second if "forces_ensemble" block (lines 119–122) has two problems:

  1. The condition checks "forces_ensemble" instead of "stress_ensemble".
  2. It transposes the already-transposed forces_ensemble (mutated in lines 116–118) and assigns it to stress_ensemble, silently corrupting the per-model stress ensemble output.

This pre-existing bug is now exercised on every code path that goes through _wrap_model for a full ensemble.

🐛 Proposed fix
-        if "forces_ensemble" in ensemble.keys():
-            ensemble["stress_ensemble"] = jnp.transpose(
-                ensemble["forces_ensemble"], (1, 2, 0)
-            )
+        if "stress_ensemble" in ensemble.keys():
+            ensemble["stress_ensemble"] = jnp.transpose(
+                ensemble["stress_ensemble"], (1, 2, 0)
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@apax/md/ase_calc.py` around lines 119 - 122, The bug is that the second
conditional incorrectly checks for "forces_ensemble" and assigns a transposed
forces array to stress_ensemble, corrupting per-model stress outputs; in
_wrap_model, change the second if to check "stress_ensemble" (not
"forces_ensemble") and ensure it assigns/constructs ensemble["stress_ensemble"]
from the correct source (e.g., transpose the original per-model stress data or
skip reusing forces_ensemble), so that ensemble["stress_ensemble"] is derived
from the proper stress data rather than the already-transposed forces_ensemble.
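
The corruption is easy to see from array shapes alone. The sketch below uses NumPy instead of `jax.numpy` and assumes, for illustration only, that both ensemble arrays carry the model axis first: `(n_models, N, 3)` for forces and `(n_models, 3, 3)` for stress. The function names are hypothetical.

```python
import numpy as np

n_models, n_atoms = 4, 5
raw = {
    "forces_ensemble": np.zeros((n_models, n_atoms, 3)),
    "stress_ensemble": np.ones((n_models, 3, 3)),
}


def transpose_buggy(ensemble):
    out = dict(ensemble)
    if "forces_ensemble" in out:
        out["forces_ensemble"] = np.transpose(out["forces_ensemble"], (1, 2, 0))
    if "forces_ensemble" in out:  # wrong key: should test "stress_ensemble"
        out["stress_ensemble"] = np.transpose(out["forces_ensemble"], (1, 2, 0))
    return out


def transpose_fixed(ensemble):
    out = dict(ensemble)
    # Move the model axis to the end for each ensemble that is present.
    for key in ("forces_ensemble", "stress_ensemble"):
        if key in out:
            out[key] = np.transpose(out[key], (1, 2, 0))
    return out


buggy, fixed = transpose_buggy(raw), transpose_fixed(raw)
print("buggy stress shape:", buggy["stress_ensemble"].shape)
print("fixed stress shape:", fixed["stress_ensemble"].shape)
```

In the buggy variant the stress entry not only has the wrong shape, it also holds force values (all zeros here) instead of the stress values (all ones).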

329-343: ⚠️ Potential issue | 🟡 Minor

Clarify design intent for calc_hessian parameter in main model vs. hessian_step.

The main model can emit Hessian in the forward pass depending on calc_hessian (lines 173–175 in models.py). When calc_hessian=False (default), the main model is lightweight and _initialize_hessian lazily builds a separate model with calc_hessian=True only when Hessian is requested (guards at lines 505, 516). When calc_hessian=True, the main model itself produces Hessian, making _initialize_hessian unnecessary.

Add a comment explaining this design choice: why users might set calc_hessian=True (to compute Hessian in the main forward pass once and reuse it) vs. relying on lazy initialization (to keep the main model cheap by default and only pay Hessian cost when needed).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@apax/md/ase_calc.py` around lines 329 - 343, Add a short design-intent
comment above the self.step / self.hessian_step setup in ase_calc.py explaining
the semantics of the calc_hessian flag: when calc_hessian=True the main model
(returned by _wrap_model / passed into get_step_fn) will compute and reuse the
Hessian in its forward pass so _initialize_hessian is unnecessary, whereas the
default calc_hessian=False keeps the main model lightweight and defers building
a separate Hessian-capable model lazily via _initialize_hessian only when a
Hessian is requested; mention that users should set calc_hessian=True to pay the
Hessian cost once and reuse it, or leave it False to avoid Hessian overhead
until needed.
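
The trade-off described above can be sketched with a toy calculator. Everything here is hypothetical and only mirrors the described behaviour: `LazyHessianCalculator`, `build_step`, and the quadratic toy energy are illustrative names and are not part of apax.

```python
class LazyHessianCalculator:
    """Toy stand-in for a calculator with an optional Hessian path."""

    def __init__(self, build_step, calc_hessian=False):
        self._build_step = build_step
        # The main step is always built; it emits a Hessian only if requested.
        self.step = build_step(calc_hessian)
        # With calc_hessian=True the main step is reused for Hessian requests.
        self.hessian_step = self.step if calc_hessian else None

    def get_hessian(self, positions):
        # Lazy path: build the Hessian-capable step on first use only.
        if self.hessian_step is None:
            self.hessian_step = self._build_step(True)
        return self.hessian_step(positions)["hessian"]


build_calls = []  # records how often a step function is built


def build_step(with_hessian):
    build_calls.append(with_hessian)

    def step(positions):
        # Toy energy E = sum(x^2); its Hessian is 2 * identity.
        results = {"energy": sum(x * x for x in positions)}
        if with_hessian:
            n = len(positions)
            results["hessian"] = [
                [2.0 if i == j else 0.0 for j in range(n)] for i in range(n)
            ]
        return results

    return step


calc = LazyHessianCalculator(build_step)
energy = calc.step([1.0, 2.0])["energy"]
hessian = calc.get_hessian([1.0, 2.0])
hessian_again = calc.get_hessian([1.0, 2.0])  # reuses the cached step
```

In this sketch the default path builds the cheap step once and the Hessian step once, on first request; passing `calc_hessian=True` would build a single step that serves both.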
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@apax/md/ase_calc.py`:
- Around line 165-175: The example uses a non-existent ASE Atoms method
atoms.get_hessian(); update the example to call the calculator method instead by
invoking ASECalculator.get_hessian (the calc.get_hessian method implemented
around the get_hessian definition) — e.g. replace atoms.get_hessian() with
calc.get_hessian(atoms) or calc.get_hessian(atoms=atoms) and adjust any
docstrings/comments that reference atoms.get_hessian() accordingly.
- Around line 266-269: Remove the unused local variable is_shallow_ensemble that
is assigned from self.model_config.model.ensemble and its kind check; delete the
three-line assignment (the is_shallow_ensemble name and its computation) from
the function so the code no longer defines an unused variable and Ruff F841 is
resolved. If the ensemble-kind check is needed elsewhere, replace usages with
direct checks against self.model_config.model.ensemble.kind instead of
introducing a new unused variable.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a99e750 and 6b5ead6.

📒 Files selected for processing (5)
  • apax/md/ase_calc.py
  • apax/md/simulate.py
  • docs/source/index.rst
  • docs/source/usage/property_prediction.rst
  • tests/integration_tests/md/test_ase_hessian.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • apax/md/simulate.py
  • tests/integration_tests/md/test_ase_hessian.py
  • docs/source/index.rst
  • docs/source/usage/property_prediction.rst

empirical_corrections: list[EmpiricalCorrection] = []

calc_stress: bool = False
calc_hessian: bool = False

Not a critique, but have you considered:

output-properties:
  - stress
  - hessian

which would also allow, in the future:

output-properties:
  stress:
    - analytical: true
  ...

return model

def _update_implemented_properties(self):
"""

I am not sure I fully understand this function. Isn't the capability of the model defined at __init__ (except maybe for ensembles)? Why is it dynamic? This is out of scope for the PR; I'm just curious.

Comment on lines +543 to +547
force_variance = "forces_uncertainty" in md_config.properties
if md_config.dynamics_checks:
    for check in md_config.dynamics_checks:
        if check.name == "forces_uncertainty":
            force_variance = True

Isn't this true for all others as well, e.g. energy, forces, stresses, ... ?

Comment on lines +218 to +219
force_variance = self.config["ensemble"]["force_variance"]


potential key error if self.config["ensemble"] is None? Can this happen?
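
A defensive lookup along these lines would cover both cases. This is a minimal sketch assuming the config is a plain nested dict; the helper name `get_force_variance` and the `False` default are hypothetical. Note that indexing `None` raises a `TypeError` rather than a `KeyError`, while a missing `"ensemble"` key raises a `KeyError`.

```python
def get_force_variance(config, default=False):
    """Read config["ensemble"]["force_variance"] without raising.

    Handles a missing "ensemble" key, an "ensemble" value of None,
    and a missing "force_variance" key.
    """
    ensemble_cfg = config.get("ensemble") or {}
    return ensemble_cfg.get("force_variance", default)


no_ensemble = get_force_variance({"ensemble": None})
missing_key = get_force_variance({})
enabled = get_force_variance({"ensemble": {"force_variance": True}})
```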

Comment on lines +31 to +36
positions = np.array(
    [[0.0, 0.0, 0.0], [0.0, 0.7, 0.5], [0.0, -0.7, 0.5]], dtype=np.float64
)
atomic_numbers = np.array([8, 1, 1])
box = np.array([0.0, 0.0, 0.0], dtype=np.float64)
atoms = Atoms(atomic_numbers, positions, cell=box)

extract into pytest.fixture?

Comment on lines +49 to +50
idx = jnp.array([[1, 2, 0, 2, 0, 1], [0, 0, 1, 1, 2, 2]])
offsets = jnp.zeros((6, 3), dtype=jnp.float64)

What are idx and offsets?
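
One plausible reading, offered as an assumption rather than a statement about the apax API: `idx` is a dense neighbor list holding every ordered pair of distinct atoms in the three-atom molecule, and `offsets` holds one periodic-image shift per pair, all zero for a non-periodic system. Which row is the neighbor and which is the central atom is a guess. The sketch below uses a hypothetical helper, `full_neighbor_list`, and reproduces the hard-coded array from the test.

```python
def full_neighbor_list(n_atoms):
    """Build all ordered pairs (i, j) with i != j, grouped by j.

    Returns idx as two rows (assumed: neighbors, then centers) and one
    zero offset vector per pair, i.e. no periodic image shifts.
    """
    neighbors, centers = [], []
    for j in range(n_atoms):
        for i in range(n_atoms):
            if i != j:
                neighbors.append(i)
                centers.append(j)
    idx = [neighbors, centers]
    offsets = [[0.0, 0.0, 0.0] for _ in neighbors]
    return idx, offsets


idx, offsets = full_neighbor_list(3)
```

For three atoms this gives six pairs, matching the `(6, 3)` shape of `offsets` in the test.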



def test_hessian_prediction():
    # Use float64 for better precision in Hessian

The test says to use float64. How does the model internally treat precision? Is this the default, or is it configurable?
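
As background, JAX defaults to 32-bit floats unless the `jax_enable_x64` flag is enabled; whether and where apax sets it is not shown in this thread. The sketch below does not use the model at all. It only illustrates, with NumPy and a central finite difference on `sin`, why second derivatives are sensitive to precision; the function name `second_derivative` and the step size are illustrative choices.

```python
import numpy as np


def second_derivative(f, x, h, dtype):
    """Central finite-difference second derivative evaluated in `dtype`."""
    x = dtype(x)
    h = dtype(h)
    return (f(x + h) - dtype(2) * f(x) + f(x - h)) / (h * h)


exact = -np.sin(1.0)  # d^2/dx^2 sin(x) = -sin(x)
err32 = abs(float(second_derivative(np.sin, 1.0, 1e-3, np.float32)) - exact)
err64 = abs(float(second_derivative(np.sin, 1.0, 1e-3, np.float64)) - exact)
```

The numerator is a difference of nearly equal numbers, so rounding error is amplified by `1 / h**2`, and the float32 result is far less accurate than the float64 one. This applies most directly to numerical Hessians used as a reference for the analytical one.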

Comment on lines +155 to +156
/*.py
*.dvc

???
