Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cagrikymk committed Feb 11, 2024
1 parent 4671a3f commit 68c6e46
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ jaxreaxff --init_FF Datasets/cobalt/ffield_lit \
```
**5-** To have the GPU support, jaxlib with CUDA support needs to be installed, otherwise the code can only run on CPUs.
```
# install jaxlib-0.3.0 with Python 3.8, CUDA-11 and cuDNN-8.05 support
pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
You can learn more about JAX installation here: [JAX install guide](https://github.com/google/jax#installation)<br>
Expand Down
5 changes: 3 additions & 2 deletions jaxreaxff/helper_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,17 @@ def loss_function(force_field, structure, nbr_lists,
charge_loss = 0
if use_forces:
atom_mask = structure.atom_types >= 0
(energy_vals, charges), forces = jax.vmap(jax.value_and_grad(calculate_energy_and_charges),
(energy_vals, charges), forces = jax.vmap(jax.value_and_grad(calculate_energy_and_charges, has_aux=True),
(0,0,0,None))(structure.positions, structure, nbr_lists, force_field)
forces = forces * atom_mask[:,:, jnp.newaxis]
force_err = (forces - structure.target_f) ** 2
force_err = force_err * atom_mask[:,:, jnp.newaxis]
force_loss = jnp.sum(force_err/(structure.atom_count.reshape(-1,1,1) * 3)) * force_w
else:
energy_vals, charges = jax.vmap(calculate_energy_and_charges,
(0,0,0,None))(structure.positions, structure, nbr_lists, force_field)
if use_charges:
charge_err = (charges - structure.target_ch) ** 2
charge_err = charge_err * atom_mask
charge_loss = jnp.sum(charge_err/structure.atom_count.reshape(-1,1)) * charge_w

target_energies = structure.target_e
Expand Down

0 comments on commit 68c6e46

Please sign in to comment.