Skip to content

Commit 557434a

Browse files
committed
sketch of ensemble energy fn
1 parent 8bacd0e commit 557434a

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

apax/md/nvt.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525
log = logging.getLogger(__name__)
2626

2727

28+
def make_energy_ensemble(model):
29+
def ensemble(positions, Z, idx, box, offsets):
30+
energies = model(positions, Z, idx, box, offsets)
31+
energy = jnp.mean(energies)
32+
33+
return energy
34+
35+
return ensemble
36+
37+
2838
def heights_of_box_sides(box):
2939
heights = []
3040

@@ -350,7 +360,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
350360
model.apply,
351361
params,
352362
Z=system.atomic_numbers,
353-
box=system.box,
363+
box=system.box, # TODO IS THIS CORRECT FOR NPT???
354364
offsets=jnp.array([0.0, 0.0, 0.0]),
355365
)
356366
sim_fns = SimulationFunctions(energy_fn, shift_fn, neighbor_fn)

0 commit comments

Comments
 (0)