From 68c6e4603f4cb913107c5ec470c92b86bde66080 Mon Sep 17 00:00:00 2001 From: Cagri Kaymak Date: Sun, 11 Feb 2024 05:43:31 -0500 Subject: [PATCH] minor fixes --- README.md | 1 - jaxreaxff/helper_v2.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ff2fe9c..d77c2c5 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,6 @@ jaxreaxff --init_FF Datasets/cobalt/ffield_lit \ ``` **5-** To have the GPU support, jaxlib with CUDA support needs to be installed, otherwise the code can only run on CPUs. ``` -# install jaxlib-0.3.0 with Python 3.8, CUDA-11 and cuDNN-8.05 support pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` You can learn more about JAX installation here: [JAX install guide](https://github.com/google/jax#installation)
diff --git a/jaxreaxff/helper_v2.py b/jaxreaxff/helper_v2.py index e4b27a9..0e641b6 100644 --- a/jaxreaxff/helper_v2.py +++ b/jaxreaxff/helper_v2.py @@ -202,16 +202,17 @@ def loss_function(force_field, structure, nbr_lists, charge_loss = 0 if use_forces: atom_mask = structure.atom_types >= 0 - (energy_vals, charges), forces = jax.vmap(jax.value_and_grad(calculate_energy_and_charges), + (energy_vals, charges), forces = jax.vmap(jax.value_and_grad(calculate_energy_and_charges, has_aux=True), (0,0,0,None))(structure.positions, structure, nbr_lists, force_field) - forces = forces * atom_mask[:,:, jnp.newaxis] force_err = (forces - structure.target_f) ** 2 + force_err = force_err * atom_mask[:,:, jnp.newaxis] force_loss = jnp.sum(force_err/(structure.atom_count.reshape(-1,1,1) * 3)) * force_w else: energy_vals, charges = jax.vmap(calculate_energy_and_charges, (0,0,0,None))(structure.positions, structure, nbr_lists, force_field) if use_charges: charge_err = (charges - structure.target_ch) ** 2 + charge_err = charge_err * atom_mask charge_loss = jnp.sum(charge_err/structure.atom_count.reshape(-1,1)) * charge_w target_energies = structure.target_e