sketch of hessian and vibrational analysis implementation#532
sketch of hessian and vibrational analysis implementation#532M-R-Schaefer wants to merge 18 commits intomainfrom
Conversation
for more information, see https://pre-commit.ci
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds optional analytical Hessian support end-to-end: config/model flags to enable Hessian, model changes to compute/export Hessians, ASECalculator integration and accessor, neighbor-function and builder plumbing, tests, and a tutorial notebook demonstrating analytical vs numerical Hessians. Changes
Sequence DiagramsequenceDiagram
participant User as User Code
participant ASECalc as ASECalculator
participant Model as EnergyDerivativeModel / Ensemble
participant JAX as JAX Runtime
participant ASE as ASE Vibrations
User->>ASECalc: calculate(atoms, properties=["hessian"])
ASECalc->>ASECalc: ensure model wrapped (_wrap_model)
ASECalc->>Model: call hessian_step(positions)
Model->>JAX: evaluate jax.hessian(energy_fn)(positions)
JAX-->>Model: Hessian tensor
Model-->>ASECalc: Hessian (raw)
ASECalc->>ASECalc: reshape to 3N×3N, store results["hessian"]
ASECalc-->>User: return results including "hessian"
User->>ASECalc: get_hessian(atoms)
ASECalc-->>User: stored 3N×3N Hessian
alt compare with ASE Vibrations
User->>ASE: run Vibrations -> numerical Hessian
ASE-->>User: numerical Hessian
User->>User: compare analytical vs numerical
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip Try Coding Plans. Let us write the prompt for your AI agent so you can ship faster (with fewer bugs). Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
apax/nn/models.py (1)
145-145: Update EnergyDerivativeModel docstring to mention Hessian support.Line 145 introduces
calc_hessian; the class docstring still only mentions forces/stress. Consider updating it for accuracy.📝 Suggested docstring tweak
- """Transforms an EnergyModel into one that also predicts derivatives the total energy. - Can calculate forces and stress tensors. - """ + """Transforms an EnergyModel into one that also predicts derivatives the total energy. + Can calculate forces, stress tensors, and Hessians. + """🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@apax/nn/models.py` at line 145, The EnergyDerivativeModel docstring is out of date: the class now has a calc_hessian boolean (added as calc_hessian: bool = False) but the docstring only mentions forces/stress; update the EnergyDerivativeModel class docstring to explicitly state that the model can optionally compute Hessians when calc_hessian is True, describe the returned/available Hessian data alongside forces and stress, and note any performance/memory implications or shape/units conventions used for the Hessian to keep the doc accurate.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@apax/md/ase_calc.py`:
- Around line 269-282: The Hessian pipeline currently builds hessian_model via
build_hessian_neighbor_fns and sets self.hessian_step directly, bypassing any
transformations or stress processing applied to self.model; update the code so
the hessian_model is passed through the same transformation/stress pipeline (the
same transforms or ProcessStress wrapper used when constructing self.model)
before creating self.hessian_step (or clearly document that Hessians are
computed on the raw model). Ensure you reference and reuse the existing
transformation chain applied to self.model (or invoke the same ProcessStress
wrapper) so get_hessian(), self.hessian_step, and the calculator’s reported
energies/forces remain consistent.
In `@examples/05_Vibrational_Analysis.ipynb`:
- Around line 204-213: The formatted print for frequencies appends a stray
backslash in the f-string; in the loop that prints first 10 frequencies (using
freqs_ana, freqs_num and diff) remove the trailing "\\" from the final print
call (the line printing f"{first:>15} {second:>15} {diff:.4f}\\") so it prints a
normal newline instead of a literal backslash; keep the rest of the formatting
unchanged.
---
Nitpick comments:
In `@apax/nn/models.py`:
- Line 145: The EnergyDerivativeModel docstring is out of date: the class now
has a calc_hessian boolean (added as calc_hessian: bool = False) but the
docstring only mentions forces/stress; update the EnergyDerivativeModel class
docstring to explicitly state that the model can optionally compute Hessians
when calc_hessian is True, describe the returned/available Hessian data
alongside forces and stress, and note any performance/memory implications or
shape/units conventions used for the Hessian to keep the doc accurate.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
apax/md/ase_calc.pyapax/nn/models.pydocs/source/_tutorials/05_Vibrational_Analysis.nblinkdocs/source/_tutorials/index.rstexamples/05_Vibrational_Analysis.ipynbtests/integration_tests/md/test_ase_hessian.pytests/unit_tests/model/test_hessian.py
| "print(\"\\nFirst 10 Frequencies (cm^-1):\")\n", | ||
| "print(f\"{'Analytical':>15} {'Numerical':>15} {'Diff':>10}\")\n", | ||
| "for i in range(10):\n", | ||
| " f_a = freqs_ana[i].real if freqs_ana[i].is_real else freqs_ana[i].imag * 1j\n", | ||
| " f_n = freqs_num[i].real if freqs_num[i].is_real else freqs_num[i].imag * 1j\n", | ||
| " diff = np.abs(f_a - f_n)\n", | ||
| " first = str(f_a)[:14]\n", | ||
| " second = str(f_n)[:14]\n", | ||
| " print(f\"{first:>15} {second:>15} {diff:.4f}\\\\\")" | ||
| ] |
There was a problem hiding this comment.
Remove stray backslash in frequency printout.
Line 212 prints an extra \ at the end of each line.
✂️ Suggested fix
- print(f"{first:>15} {second:>15} {diff:.4f}\\")
+ print(f"{first:>15} {second:>15} {diff:.4f}")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "print(\"\\nFirst 10 Frequencies (cm^-1):\")\n", | |
| "print(f\"{'Analytical':>15} {'Numerical':>15} {'Diff':>10}\")\n", | |
| "for i in range(10):\n", | |
| " f_a = freqs_ana[i].real if freqs_ana[i].is_real else freqs_ana[i].imag * 1j\n", | |
| " f_n = freqs_num[i].real if freqs_num[i].is_real else freqs_num[i].imag * 1j\n", | |
| " diff = np.abs(f_a - f_n)\n", | |
| " first = str(f_a)[:14]\n", | |
| " second = str(f_n)[:14]\n", | |
| " print(f\"{first:>15} {second:>15} {diff:.4f}\\\\\")" | |
| ] | |
| "print(\"\\nFirst 10 Frequencies (cm^-1):\")\n", | |
| "print(f\"{'Analytical':>15} {'Numerical':>15} {'Diff':>10}\")\n", | |
| "for i in range(10):\n", | |
| " f_a = freqs_ana[i].real if freqs_ana[i].is_real else freqs_ana[i].imag * 1j\n", | |
| " f_n = freqs_num[i].real if freqs_num[i].is_real else freqs_num[i].imag * 1j\n", | |
| " diff = np.abs(f_a - f_n)\n", | |
| " first = str(f_a)[:14]\n", | |
| " second = str(f_n)[:14]\n", | |
| " print(f\"{first:>15} {second:>15} {diff:.4f}\")" |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/05_Vibrational_Analysis.ipynb` around lines 204 - 213, The formatted
print for frequencies appends a stray backslash in the f-string; in the loop
that prints first 10 frequencies (using freqs_ana, freqs_num and diff) remove
the trailing "\\" from the final print call (the line printing f"{first:>15}
{second:>15} {diff:.4f}\\") so it prints a normal newline instead of a literal
backslash; keep the rest of the formatting unchanged.
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/integration_tests/md/test_ase_hessian.py (2)
75-75: Redundant Hessian computation.
analytical_hessianon line 75 duplicates the computation already done on line 67. Consider reusing thehessianvariable or removing the duplicate call.♻️ Proposed fix
# 1. Test direct hessian calculation via calculator hessian = calc.get_hessian(atoms) assert hessian.shape == (9, 9) assert not np.allclose(hessian, 0.0) # 2. Test compatibility with Vibrations module vib_dir = get_tmp_path / "vib" vib = Vibrations(atoms, name=str(vib_dir), nfree=4) - analytical_hessian = calc.get_hessian(atoms) + analytical_hessian = hessian # Reuse from direct test above # Numerical Hessian from Vibrations (finite difference of forces) vib.run()🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration_tests/md/test_ase_hessian.py` at line 75, The test calls calc.get_hessian(atoms) twice: once assigned to hessian and again to analytical_hessian; remove the redundant second call by reusing the existing hessian variable (or rename hessian to analytical_hessian consistently) so there is only a single calc.get_hessian(atoms) invocation; update assertions that reference analytical_hessian to use hessian (or vice versa) and delete the duplicate line containing analytical_hessian = calc.get_hessian(atoms).
66-69: Consider adding a Hessian symmetry assertion.The Hessian matrix should be symmetric by definition. Adding a symmetry check would strengthen the test and catch potential implementation bugs.
💡 Suggested addition
hessian = calc.get_hessian(atoms) assert hessian.shape == (9, 9) assert not np.allclose(hessian, 0.0) + # Hessian should be symmetric + assert np.allclose(hessian, hessian.T), "Hessian matrix is not symmetric"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration_tests/md/test_ase_hessian.py` around lines 66 - 69, After computing the Hessian with calc.get_hessian(atoms) (variable hessian), add an assertion that the matrix is symmetric by comparing it to its transpose using numpy allclose (e.g., assert np.allclose(hessian, hessian.T, atol=1e-8) or a suitable tolerance) to catch asymmetry bugs in the implementation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/integration_tests/md/test_ase_hessian.py`:
- Line 19: Rename the misspelled variable model_confg_path to model_config_path
and update all usages (e.g., where it's referenced on line 21) to the corrected
name so the test uses model_config_path consistently; ensure you only change the
identifier and not surrounding logic or values.
---
Nitpick comments:
In `@tests/integration_tests/md/test_ase_hessian.py`:
- Line 75: The test calls calc.get_hessian(atoms) twice: once assigned to
hessian and again to analytical_hessian; remove the redundant second call by
reusing the existing hessian variable (or rename hessian to analytical_hessian
consistently) so there is only a single calc.get_hessian(atoms) invocation;
update assertions that reference analytical_hessian to use hessian (or vice
versa) and delete the duplicate line containing analytical_hessian =
calc.get_hessian(atoms).
- Around line 66-69: After computing the Hessian with calc.get_hessian(atoms)
(variable hessian), add an assertion that the matrix is symmetric by comparing
it to its transpose using numpy allclose (e.g., assert np.allclose(hessian,
hessian.T, atol=1e-8) or a suitable tolerance) to catch asymmetry bugs in the
implementation.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/05_Vibrational_Analysis.ipynbtests/integration_tests/md/test_ase_hessian.pytests/unit_tests/model/test_hessian.py
✅ Files skipped from review due to trivial changes (1)
- examples/05_Vibrational_Analysis.ipynb
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit_tests/model/test_hessian.py
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
apax/md/ase_calc.py (1)
254-276:⚠️ Potential issue | 🟠 MajorReset
hessian_stepon reinitialize to avoid stale Hessians.When
initialize()is called due to atomic number or cell changes,self.hessian_stepcan still point at the old Z/periodicity, yielding incorrect Hessians.✅ Proposed fix
model = self._wrap_model(model) self.model = model + self.hessian_step = None self.step = get_step_fn( model, jnp.asarray(atoms.numbers), bool(np.any(atoms.cell.array > 1e-6)), self.neigbor_from_jax, )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@apax/md/ase_calc.py` around lines 254 - 276, The initialize method may leave a stale Hessian reference in self.hessian_step when atoms (Z or cell) change; update initialize to reset self.hessian_step (e.g., set to None or 0) so any cached Hessian is invalidated on reinitialize; locate the initialize method and add a line to clear self.hessian_step before rebuilding model/step/neighbor_fn (symbols: initialize, self.hessian_step, _wrap_model, get_step_fn).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@apax/config/model_config.py`:
- Around line 186-187: The BaseModelConfig docstring is missing documentation
for the new user-facing flag calc_hessian; update the BaseModelConfig class
docstring's Parameters section to add an entry for calc_hessian (similar style
to calc_stress) describing its type (bool), default (False), and
purpose/behavior; reference the calc_hessian and calc_stress attributes in the
text so users understand the difference and any interactions or side effects.
---
Outside diff comments:
In `@apax/md/ase_calc.py`:
- Around line 254-276: The initialize method may leave a stale Hessian reference
in self.hessian_step when atoms (Z or cell) change; update initialize to reset
self.hessian_step (e.g., set to None or 0) so any cached Hessian is invalidated
on reinitialize; locate the initialize method and add a line to clear
self.hessian_step before rebuilding model/step/neighbor_fn (symbols: initialize,
self.hessian_step, _wrap_model, get_step_fn).
…s for hessian preds
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
apax/md/ase_calc.py (1)
319-365:⚠️ Potential issue | 🔴 Critical
initializedoes not resetself.hessian_step, leaving a stale closure when atoms change.
get_step_fncloses overZ = jnp.asarray(atoms.numbers). Wheninitializeis called again (e.g., because"numbers"or"cell"changed incalculate),self.stepis rebuilt with the newZ, butself.hessian_stepkeeps the old capturedZ. Every subsequent Hessian request then silently computes the Hessian for the wrong species.🐛 Proposed fix
def initialize(self, atoms): box = jnp.asarray(atoms.cell.array, dtype=jnp.float64) self.r_max = self.model_config.model.basis.r_max self.neigbor_from_jax = neighbor_calculable_with_jax(box, self.r_max) model, neighbor_fn = build_energy_neighbor_fns( ... ) model = self._wrap_model(model) self.model = model self.step = get_step_fn( model, jnp.asarray(atoms.numbers), bool(np.any(atoms.cell.array > 1e-6)), self.neigbor_from_jax, ) + self.hessian_step = None # invalidate stale Hessian closure self.neighbor_fn = neighbor_fn ...🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@apax/md/ase_calc.py` around lines 319 - 365, The initialize method rebuilds self.step using get_step_fn but does not reset/rebuild self.hessian_step, so any previously captured Z remains stale; fix by resetting or recomputing self.hessian_step inside initialize (e.g. set self.hessian_step = None or call get_step_fn to create a fresh hessian step using the new jnp.asarray(atoms.numbers) and same neighbor/flags) so Hessian calculations use the updated species vector; update the block that assigns self.step (inside initialize) to also update self.hessian_step accordingly.
♻️ Duplicate comments (1)
tests/integration_tests/md/test_ase_hessian.py (1)
19-21: Typo in variable name:model_confg_pathshould bemodel_config_path.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration_tests/md/test_ase_hessian.py` around lines 19 - 21, There is a typo: the variable name model_confg_path should be model_config_path; rename model_confg_path to model_config_path everywhere it’s defined and referenced (e.g., the assignment using TEST_PATH and the subsequent with open(model_confg_path.as_posix(), "r") block) so all usages (opening the file, assertions, or helper calls) use the corrected identifier model_config_path.
🧹 Nitpick comments (1)
apax/md/ase_calc.py (1)
504-521: Double Hessian computation whencalc_hessian=True.When
ASECalculatoris created withcalc_hessian=True,initializebuilds the main model withcalc_hessian=True(viabuild_energy_neighbor_fns), soself.step(...)already includes"hessian"in its output. The subsequenthessian_stepcall then recomputes and overwrites the same value — doubling the cost of the most expensive operation.Consider skipping the
hessian_steppath when"hessian"is already present inresults:♻️ Proposed fix (JAX path; mirror for non-JAX)
results, self.neighbors = self.step(positions, self.neighbors, box) ... -if "hessian" in properties: +if "hessian" in properties and "hessian" not in results: if self.hessian_step is None: self._initialize_hessian(atoms) hessian_results, _ = self.hessian_step(positions, self.neighbors, box) results["hessian"] = hessian_results["hessian"]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@apax/md/ase_calc.py` around lines 504 - 521, The code currently always calls self.hessian_step and overwrites results["hessian"] even when self.step already produced a "hessian" (causing double work); update both branches (the JAX branch where results = self.step(positions, ...) and the non-JAX branch) to skip invoking self.hessian_step when "hessian" is already present in results: check if "hessian" not in results before calling self._initialize_hessian(...) / self.hessian_step(...), and only set results["hessian"] from hessian_results when you actually ran hessian_step; keep use of self._initialize_hessian, self.hessian_step, self.step, and the existing signature/arguments unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@apax/md/simulate.py`:
- Around line 543-546: The new loop that sets force_variance by iterating
md_config.dynamics_checks can raise TypeError if dynamics_checks is None; update
the logic to check that md_config.dynamics_checks is truthy before iterating
(e.g., only run the for check in md_config.dynamics_checks when
md_config.dynamics_checks is not None/empty) so force_variance is determined
safely (preserve the existing initial assignment force_variance =
"forces_uncertainty" in md_config.properties and the loop that checks check.name
== "forces_uncertainty").
In `@docs/source/index.rst`:
- Line 6: Fix the stray period in the documentation sentence: locate the
sentence containing "`Apax` is a high-performance, extendable package for
training of and inference with atomistic neural networks., most prominently the
Gaussian Moment Neural Network model." and remove the errant '.' so it reads
"...atomistic neural networks, most prominently the Gaussian Moment Neural
Network model." ensuring the punctuation is correct and no extra characters
remain.
In `@docs/source/usage/property_prediction.rst`:
- Around line 34-39: The example incorrectly calls atoms.get_hessian()
(get_hessian is on the ASECalculator, not Atoms) and then passes a 2D (3N×3N)
matrix into VibrationsData.__init__ which expects a 4-D (N,3,N,3) Hessian; fix
by obtaining the Hessian from the calculator (call the ASECalculator instance's
get_hessian method) and then construct the vibrational data with
VibrationsData.from_2d(atoms, hessian) instead of VibrationsData(atoms,
hessian).
---
Outside diff comments:
In `@apax/md/ase_calc.py`:
- Around line 319-365: The initialize method rebuilds self.step using
get_step_fn but does not reset/rebuild self.hessian_step, so any previously
captured Z remains stale; fix by resetting or recomputing self.hessian_step
inside initialize (e.g. set self.hessian_step = None or call get_step_fn to
create a fresh hessian step using the new jnp.asarray(atoms.numbers) and same
neighbor/flags) so Hessian calculations use the updated species vector; update
the block that assigns self.step (inside initialize) to also update
self.hessian_step accordingly.
---
Duplicate comments:
In `@tests/integration_tests/md/test_ase_hessian.py`:
- Around line 19-21: There is a typo: the variable name model_confg_path should
be model_config_path; rename model_confg_path to model_config_path everywhere
it’s defined and referenced (e.g., the assignment using TEST_PATH and the
subsequent with open(model_confg_path.as_posix(), "r") block) so all usages
(opening the file, assertions, or helper calls) use the corrected identifier
model_config_path.
---
Nitpick comments:
In `@apax/md/ase_calc.py`:
- Around line 504-521: The code currently always calls self.hessian_step and
overwrites results["hessian"] even when self.step already produced a "hessian"
(causing double work); update both branches (the JAX branch where results =
self.step(positions, ...) and the non-JAX branch) to skip invoking
self.hessian_step when "hessian" is already present in results: check if
"hessian" not in results before calling self._initialize_hessian(...) /
self.hessian_step(...), and only set results["hessian"] from hessian_results
when you actually ran hessian_step; keep use of self._initialize_hessian,
self.hessian_step, self.step, and the existing signature/arguments unchanged.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
apax/md/ase_calc.pyapax/md/simulate.pyapax/nn/builder.pydocs/source/index.rstdocs/source/usage/index.rstdocs/source/usage/property_prediction.rsttests/integration_tests/md/test_ase_hessian.pytests/unit_tests/md/test_ase_calc.pytests/unit_tests/md/test_create_openmm_simulation.py
✅ Files skipped from review due to trivial changes (1)
- docs/source/usage/index.rst
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
apax/md/ase_calc.py (2)
119-122:⚠️ Potential issue | 🟠 Major
stress_ensembleis incorrectly overwritten with transposedforces_ensemble.The second
if "forces_ensemble"block (lines 119–122) has two problems:
- The condition checks
"forces_ensemble"instead of"stress_ensemble".- It transposes the already-transposed
forces_ensemble(mutated in lines 116–118) and assigns it tostress_ensemble, silently corrupting the per-model stress ensemble output.This pre-existing bug is now exercised on every code path that goes through
_wrap_modelfor a full ensemble.🐛 Proposed fix
- if "forces_ensemble" in ensemble.keys(): - ensemble["stress_ensemble"] = jnp.transpose( - ensemble["forces_ensemble"], (1, 2, 0) - ) + if "stress_ensemble" in ensemble.keys(): + ensemble["stress_ensemble"] = jnp.transpose( + ensemble["stress_ensemble"], (1, 2, 0) + )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@apax/md/ase_calc.py` around lines 119 - 122, The bug is that the second conditional incorrectly checks for "forces_ensemble" and assigns a transposed forces array to stress_ensemble, corrupting per-model stress outputs; in _wrap_model, change the second if to check "stress_ensemble" (not "forces_ensemble") and ensure it assigns/constructs ensemble["stress_ensemble"] from the correct source (e.g., transpose the original per-model stress data or skip reusing forces_ensemble), so that ensemble["stress_ensemble"] is derived from the proper stress data rather than the already-transposed forces_ensemble.
329-343:⚠️ Potential issue | 🟡 MinorClarify design intent for
calc_hessianparameter in main model vs. hessian_step.The main model can emit Hessian in the forward pass depending on
calc_hessian(lines 173–175 inmodels.py). Whencalc_hessian=False(default), the main model is lightweight and_initialize_hessianlazily builds a separate model withcalc_hessian=Trueonly when Hessian is requested (guards at lines 505, 516). Whencalc_hessian=True, the main model itself produces Hessian, making_initialize_hessianunnecessary.Add a comment explaining this design choice: why users might set
calc_hessian=True(to compute Hessian in the main forward pass once and reuse it) vs. relying on lazy initialization (to keep the main model cheap by default and only pay Hessian cost when needed).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@apax/md/ase_calc.py` around lines 329 - 343, Add a short design-intent comment above the self.step / self.hessian_step setup in ase_calc.py explaining the semantics of the calc_hessian flag: when calc_hessian=True the main model (returned by _wrap_model / passed into get_step_fn) will compute and reuse the Hessian in its forward pass so _initialize_hessian is unnecessary, whereas the default calc_hessian=False keeps the main model lightweight and defers building a separate Hessian-capable model lazily via _initialize_hessian only when a Hessian is requested; mention that users should set calc_hessian=True to pay the Hessian cost once and reuse it, or leave it False to avoid Hessian overhead until needed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@apax/md/ase_calc.py`:
- Around line 165-175: The example uses a non-existent ASE Atoms method
atoms.get_hessian(); update the example to call the calculator method instead by
invoking ASECalculator.get_hessian (the calc.get_hessian method implemented
around the get_hessian definition) — e.g. replace atoms.get_hessian() with
calc.get_hessian(atoms) or calc.get_hessian(atoms=atoms) and adjust any
docstrings/comments that reference atoms.get_hessian() accordingly.
- Around line 266-269: Remove the unused local variable is_shallow_ensemble that
is assigned from self.model_config.model.ensemble and its kind check; delete the
three-line assignment (the is_shallow_ensemble name and its computation) from
the function so the code no longer defines an unused variable and Ruff F841 is
resolved. If the ensemble-kind check is needed elsewhere, replace usages with
direct checks against self.model_config.model.ensemble.kind instead of
introducing a new unused variable.
---
Outside diff comments:
In `@apax/md/ase_calc.py`:
- Around line 119-122: The bug is that the second conditional incorrectly checks
for "forces_ensemble" and assigns a transposed forces array to stress_ensemble,
corrupting per-model stress outputs; in _wrap_model, change the second if to
check "stress_ensemble" (not "forces_ensemble") and ensure it assigns/constructs
ensemble["stress_ensemble"] from the correct source (e.g., transpose the
original per-model stress data or skip reusing forces_ensemble), so that
ensemble["stress_ensemble"] is derived from the proper stress data rather than
the already-transposed forces_ensemble.
- Around line 329-343: Add a short design-intent comment above the self.step /
self.hessian_step setup in ase_calc.py explaining the semantics of the
calc_hessian flag: when calc_hessian=True the main model (returned by
_wrap_model / passed into get_step_fn) will compute and reuse the Hessian in its
forward pass so _initialize_hessian is unnecessary, whereas the default
calc_hessian=False keeps the main model lightweight and defers building a
separate Hessian-capable model lazily via _initialize_hessian only when a
Hessian is requested; mention that users should set calc_hessian=True to pay the
Hessian cost once and reuse it, or leave it False to avoid Hessian overhead
until needed.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
apax/md/ase_calc.pyapax/md/simulate.pydocs/source/index.rstdocs/source/usage/property_prediction.rsttests/integration_tests/md/test_ase_hessian.py
🚧 Files skipped from review as they are similar to previous changes (4)
- apax/md/simulate.py
- tests/integration_tests/md/test_ase_hessian.py
- docs/source/index.rst
- docs/source/usage/property_prediction.rst
9b766d1 to
6668cd6
Compare
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
| empirical_corrections: list[EmpiricalCorrection] = [] | ||
|
|
||
| calc_stress: bool = False | ||
| calc_hessian: bool = False |
There was a problem hiding this comment.
Not a critique, but have you considered:
output-properties:
- stress
- hessianwhich would also allow for future
output-propterties:
stress:
- analytical: true
...| return model | ||
|
|
||
| def _update_implemented_properties(self): | ||
| """ |
There was a problem hiding this comment.
I am not sure I fully understand this function, isn't the capability of the model defined at __init__ (except for ensembles maybe?). Why dynamic? Out of scope of the PR, just curious
| force_variance = "forces_uncertainty" in md_config.properties | ||
| if md_config.dynamics_checks: | ||
| for check in md_config.dynamics_checks: | ||
| if check.name == "forces_uncertainty": | ||
| force_variance = True |
There was a problem hiding this comment.
Isn't this true for all others as well, e.g. energy, forces, stresses, ... ?
| force_variance = self.config["ensemble"]["force_variance"] | ||
|
|
There was a problem hiding this comment.
potential key error if self.config["ensemble"] is None? Can this happen?
| positions = np.array( | ||
| [[0.0, 0.0, 0.0], [0.0, 0.7, 0.5], [0.0, -0.7, 0.5]], dtype=np.float64 | ||
| ) | ||
| atomic_numbers = np.array([8, 1, 1]) | ||
| box = np.array([0.0, 0.0, 0.0], dtype=np.float64) | ||
| atoms = Atoms(atomic_numbers, positions, cell=box) |
There was a problem hiding this comment.
extract into pytest.fixture?
| idx = jnp.array([[1, 2, 0, 2, 0, 1], [0, 0, 1, 1, 2, 2]]) | ||
| offsets = jnp.zeros((6, 3), dtype=jnp.float64) |
|
|
||
|
|
||
| def test_hessian_prediction(): | ||
| # Use float64 for better precision in Hessian |
There was a problem hiding this comment.
the test states use float64, how does the model interally treat the precision? Is this default / configurable?
| /*.py | ||
| *.dvc |
Summary by CodeRabbit
New Features
Documentation
Tests