Skip to content

Commit ea4e273

Browse files
committed
Merge branch 'dev' into bal
2 parents 63e3981 + 079608f commit ea4e273

File tree

4 files changed

+140
-41
lines changed

4 files changed

+140
-41
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ pip install --upgrade pip
3939

4040
CUDA 12 installation. Wheels only available on linux.
4141
```bash
42-
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
42+
pip install --upgrade "jax[cuda12_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
4343
```
4444

4545
CUDA 11 installation. Wheels only available on linux.
4646
```bash
47-
pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
47+
pip install --upgrade "jax[cuda11_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
4848
```
4949

5050
See the [Jax installation instructions](https://github.com/google/jax#installation) for more details.

apax/md/ase_calc.py

Lines changed: 136 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ase.calculators.calculator import Calculator, all_changes
99
from flax.traverse_util import flatten_dict
1010
from jax_md import partition, quantity, space
11+
from matscipy.neighbours import neighbour_list
1112

1213
from apax.model import ModelBuilder
1314
from apax.train.checkpoints import restore_parameters
@@ -28,32 +29,41 @@ def maybe_vmap(apply, params, Z):
2829
return energy_fn
2930

3031

31-
def build_energy_neighbor_fns(atoms, config, params, dr_threshold):
32+
def build_energy_neighbor_fns(atoms, config, params, dr_threshold, neigbor_from_jax):
33+
r_max = config.model.r_max
3234
atomic_numbers = jnp.asarray(atoms.numbers)
33-
box = jnp.asarray(atoms.get_cell().array, dtype=jnp.float32)
35+
box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
36+
neigbor_from_jax = neighbor_calculable_with_jax(box, r_max)
3437
box = box.T
35-
36-
if np.all(box < 1e-6):
37-
displacement_fn, _ = space.free()
38-
else:
39-
displacement_fn, _ = space.periodic_general(box, fractional_coordinates=True)
38+
displacement_fn = None
39+
neighbor_fn = None
40+
41+
if neigbor_from_jax:
42+
if np.all(box < 1e-6):
43+
displacement_fn, _ = space.free()
44+
else:
45+
displacement_fn, _ = space.periodic_general(box, fractional_coordinates=True)
46+
47+
neighbor_fn = jax_md_reduced.partition.neighbor_list(
48+
displacement_fn,
49+
box,
50+
config.model.r_max,
51+
dr_threshold,
52+
fractional_coordinates=True,
53+
disable_cell_list=True,
54+
format=partition.Sparse,
55+
)
4056

4157
Z = jnp.asarray(atomic_numbers)
4258
n_species = 119 # int(np.max(Z) + 1)
4359
builder = ModelBuilder(config.model.get_dict(), n_species=n_species)
60+
4461
model = builder.build_energy_derivative_model(
4562
apply_mask=True, init_box=np.array(box), inference_disp_fn=displacement_fn
4663
)
64+
4765
energy_fn = maybe_vmap(model.apply, params, Z)
48-
neighbor_fn = jax_md_reduced.partition.neighbor_list(
49-
displacement_fn,
50-
box,
51-
config.model.r_max,
52-
dr_threshold,
53-
fractional_coordinates=True,
54-
disable_cell_list=True,
55-
format=partition.Sparse,
56-
)
66+
5767
return energy_fn, neighbor_fn
5868

5969

@@ -62,7 +72,7 @@ def process_stress(results, box):
6272
results = {
6373
# We should properly check whether CP2K uses the ASE cell convention
6474
# for tetragonal strain, it doesn't matter whether we transpose or not
65-
k: (val.T / V if k.startswith("stress") else val)
75+
k: val.T / V if k.startswith("stress") else val
6676
for k, val in results.items()
6777
}
6878
return results
@@ -83,7 +93,6 @@ def ensemble(positions, Z, idx, box, offsets):
8393
class ASECalculator(Calculator):
8494
"""
8595
ASE Calculator for apax models.
86-
DOES NOT SUPPORT CUTOFFS LARGER THAN MIN(BOX SIZE / 2)!
8796
"""
8897

8998
implemented_properties = [
@@ -96,6 +105,7 @@ def __init__(
96105
model_dir: Union[Path, list[Path]],
97106
dr_threshold: float = 0.5,
98107
transformations: Callable = [],
108+
padding_factor: float = 1.5,
99109
**kwargs
100110
):
101111
Calculator.__init__(self, **kwargs)
@@ -105,6 +115,7 @@ def __init__(
105115
self.n_models = 1 if isinstance(model_dir, (Path, str)) else len(model_dir)
106116

107117
self.model_config, self.params = restore_parameters(model_dir)
118+
self.padding_factor = padding_factor
108119

109120
if self.model_config.model.calc_stress:
110121
self.implemented_properties.append("stress")
@@ -119,13 +130,18 @@ def __init__(
119130
self.step = None
120131
self.neighbor_fn = None
121132
self.neighbors = None
133+
self.offsets = None
122134

123135
def initialize(self, atoms):
136+
box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
137+
self.r_max = self.model_config.model.r_max
138+
self.neigbor_from_jax = neighbor_calculable_with_jax(box, self.r_max)
124139
model, neighbor_fn = build_energy_neighbor_fns(
125140
atoms,
126141
self.model_config,
127142
self.params,
128143
self.dr_threshold,
144+
self.neigbor_from_jax,
129145
)
130146

131147
if self.is_ensemble:
@@ -134,7 +150,99 @@ def initialize(self, atoms):
134150
for transformation in self.transformations:
135151
model = transformation.apply(model, self.n_models)
136152

137-
Z = jnp.asarray(atoms.numbers)
153+
self.step = get_step_fn(model, atoms, self.neigbor_from_jax)
154+
self.neighbor_fn = neighbor_fn
155+
156+
def set_neighbours_and_offsets(self, atoms, box):
157+
idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)
158+
159+
if len(idxs_i) > self.padded_length:
160+
print("neighbor list overflowed, reallocating.")
161+
self.padded_length = int(len(idxs_i) * self.padding_factor)
162+
self.initialize(atoms)
163+
164+
zeros_to_add = self.padded_length - len(idxs_i)
165+
166+
self.neighbors = np.array([idxs_i, idxs_j], dtype=np.int32)
167+
self.neighbors = np.pad(self.neighbors, ((0, 0), (0, zeros_to_add)), "constant")
168+
169+
offsets = np.matmul(offsets, box)
170+
self.offsets = np.pad(offsets, ((0, zeros_to_add), (0, 0)), "constant")
171+
172+
def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
173+
Calculator.calculate(self, atoms, properties, system_changes)
174+
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
175+
box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
176+
177+
# setup model and neighbours
178+
if self.step is None:
179+
self.initialize(atoms)
180+
181+
if self.neigbor_from_jax:
182+
self.neighbors = self.neighbor_fn.allocate(positions)
183+
else:
184+
idxs_i = neighbour_list("i", atoms, self.r_max)
185+
self.padded_length = int(len(idxs_i) * self.padding_factor)
186+
187+
elif "numbers" in system_changes:
188+
self.initialize(atoms)
189+
190+
if self.neigbor_from_jax:
191+
self.neighbors = self.neighbor_fn.allocate(positions)
192+
193+
elif "cell" in system_changes:
194+
neigbor_from_jax = neighbor_calculable_with_jax(box, self.r_max)
195+
if self.neigbor_from_jax != neigbor_from_jax:
196+
self.initialize(atoms)
197+
198+
# predict
199+
if self.neigbor_from_jax:
200+
results, self.neighbors = self.step(positions, self.neighbors, box)
201+
202+
if self.neighbors.did_buffer_overflow:
203+
print("neighbor list overflowed, reallocating.")
204+
self.initialize(atoms)
205+
self.neighbors = self.neighbor_fn.allocate(positions)
206+
207+
results, self.neighbors = self.step(positions, self.neighbors, box)
208+
209+
else:
210+
self.set_neighbours_and_offsets(atoms, box)
211+
positions = np.array(space.transform(np.linalg.inv(box), atoms.positions))
212+
213+
results = self.step(positions, self.neighbors, box, self.offsets)
214+
215+
self.results = {k: np.array(v, dtype=np.float64) for k, v in results.items()}
216+
self.results["energy"] = self.results["energy"].item()
217+
218+
219+
def neighbor_calculable_with_jax(box, r_max):
220+
if np.all(box < 1e-6):
221+
return True
222+
else:
223+
# all lettice vector combinations to calculate all three plane distances
224+
a_vec_list = [box[0], box[0], box[1]]
225+
b_vec_list = [box[1], box[2], box[2]]
226+
c_vec_list = [box[2], box[1], box[0]]
227+
228+
height = []
229+
for i in range(3):
230+
normvec = np.cross(a_vec_list[i], b_vec_list[i])
231+
projection = (
232+
c_vec_list[i]
233+
- np.sum(normvec * c_vec_list[i]) / np.sum(normvec**2) * normvec
234+
)
235+
height.append(np.linalg.norm(c_vec_list[i] - projection))
236+
237+
if np.min(height) / 2 > r_max:
238+
return True
239+
else:
240+
return False
241+
242+
243+
def get_step_fn(model, atoms, neigbor_from_jax):
244+
Z = jnp.asarray(atoms.numbers)
245+
if neigbor_from_jax:
138246

139247
@jax.jit
140248
def step_fn(positions, neighbor, box):
@@ -145,33 +253,24 @@ def step_fn(positions, neighbor, box):
145253
neighbor = neighbor.update(positions, box=box)
146254
else:
147255
neighbor = neighbor.update(positions)
148-
offsets = jnp.full([neighbor.idx.shape[1], 3], 0)
149256

257+
offsets = jnp.full([neighbor.idx.shape[1], 3], 0)
150258
results = model(positions, Z, neighbor.idx, box, offsets)
151259

152260
if "stress" in results.keys():
153261
results = process_stress(results, box)
154262

155263
return results, neighbor
156264

157-
self.step = step_fn
158-
self.neighbor_fn = neighbor_fn
265+
else:
159266

160-
def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
161-
Calculator.calculate(self, atoms, properties, system_changes)
162-
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
163-
box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
164-
if self.step is None or "numbers" in system_changes:
165-
self.initialize(atoms)
166-
self.neighbors = self.neighbor_fn.allocate(positions)
267+
@jax.jit
268+
def step_fn(positions, neighbor, box, offsets):
269+
results = model(positions, Z, neighbor, box, offsets)
167270

168-
results, self.neighbors = self.step(positions, self.neighbors, box)
271+
if "stress" in results.keys():
272+
results = process_stress(results, box)
169273

170-
if self.neighbors.did_buffer_overflow:
171-
print("neighbor list overflowed, reallocating.")
172-
self.initialize(atoms)
173-
self.neighbors = self.neighbor_fn.allocate(positions)
174-
results, self.neighbors = self.step(positions, self.neighbors, box)
274+
return results
175275

176-
self.results = {k: np.array(v, dtype=np.float64) for k, v in results.items()}
177-
self.results["energy"] = self.results["energy"].item()
276+
return step_fn

apax/md/function_transformations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class GaussianAcceleratedMolecularDynamics(FunctionTransformation):
6767
https://pubs.acs.org/doi/10.1021/acs.jctc.5b00436
6868
6969
Parameters:
70-
energy_target: Target potential energy below wich to apply the boost potential.
70+
energy_target: Target potential energy below which to apply the boost potential.
7171
spring_constant: Spring constant of the boost potential.
7272
"""
7373

apax/optimizer/get_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def map_nested_fn(fn: Callable[[str, Any], dict]) -> Callable[[dict], dict]:
1717

1818
def map_fn(nested_dict):
1919
return {
20-
k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
20+
k: map_fn(v) if isinstance(v, dict) else fn(k, v)
2121
for k, v in nested_dict.items()
2222
}
2323

0 commit comments

Comments
 (0)