8
8
from ase .calculators .calculator import Calculator , all_changes
9
9
from flax .traverse_util import flatten_dict
10
10
from jax_md import partition , quantity , space
11
+ from matscipy .neighbours import neighbour_list
11
12
12
13
from apax .model import ModelBuilder
13
14
from apax .train .checkpoints import restore_parameters
@@ -28,32 +29,41 @@ def maybe_vmap(apply, params, Z):
28
29
return energy_fn
29
30
30
31
31
- def build_energy_neighbor_fns (atoms , config , params , dr_threshold ):
32
+ def build_energy_neighbor_fns (atoms , config , params , dr_threshold , neigbor_from_jax ):
33
+ r_max = config .model .r_max
32
34
atomic_numbers = jnp .asarray (atoms .numbers )
33
- box = jnp .asarray (atoms .get_cell ().array , dtype = jnp .float32 )
35
+ box = jnp .asarray (atoms .cell .array , dtype = jnp .float64 )
36
+ neigbor_from_jax = neighbor_calculable_with_jax (box , r_max )
34
37
box = box .T
35
-
36
- if np .all (box < 1e-6 ):
37
- displacement_fn , _ = space .free ()
38
- else :
39
- displacement_fn , _ = space .periodic_general (box , fractional_coordinates = True )
38
+ displacement_fn = None
39
+ neighbor_fn = None
40
+
41
+ if neigbor_from_jax :
42
+ if np .all (box < 1e-6 ):
43
+ displacement_fn , _ = space .free ()
44
+ else :
45
+ displacement_fn , _ = space .periodic_general (box , fractional_coordinates = True )
46
+
47
+ neighbor_fn = jax_md_reduced .partition .neighbor_list (
48
+ displacement_fn ,
49
+ box ,
50
+ config .model .r_max ,
51
+ dr_threshold ,
52
+ fractional_coordinates = True ,
53
+ disable_cell_list = True ,
54
+ format = partition .Sparse ,
55
+ )
40
56
41
57
Z = jnp .asarray (atomic_numbers )
42
58
n_species = 119 # int(np.max(Z) + 1)
43
59
builder = ModelBuilder (config .model .get_dict (), n_species = n_species )
60
+
44
61
model = builder .build_energy_derivative_model (
45
62
apply_mask = True , init_box = np .array (box ), inference_disp_fn = displacement_fn
46
63
)
64
+
47
65
energy_fn = maybe_vmap (model .apply , params , Z )
48
- neighbor_fn = jax_md_reduced .partition .neighbor_list (
49
- displacement_fn ,
50
- box ,
51
- config .model .r_max ,
52
- dr_threshold ,
53
- fractional_coordinates = True ,
54
- disable_cell_list = True ,
55
- format = partition .Sparse ,
56
- )
66
+
57
67
return energy_fn , neighbor_fn
58
68
59
69
@@ -62,7 +72,7 @@ def process_stress(results, box):
62
72
results = {
63
73
# We should properly check whether CP2K uses the ASE cell convention
64
74
# for tetragonal strain, it doesn't matter whether we transpose or not
65
- k : ( val .T / V if k .startswith ("stress" ) else val )
75
+ k : val .T / V if k .startswith ("stress" ) else val
66
76
for k , val in results .items ()
67
77
}
68
78
return results
@@ -83,7 +93,6 @@ def ensemble(positions, Z, idx, box, offsets):
83
93
class ASECalculator (Calculator ):
84
94
"""
85
95
ASE Calculator for apax models.
86
- DOES NOT SUPPORT CUTOFFS LARGER THAN MIN(BOX SIZE / 2)!
87
96
"""
88
97
89
98
implemented_properties = [
@@ -96,6 +105,7 @@ def __init__(
96
105
model_dir : Union [Path , list [Path ]],
97
106
dr_threshold : float = 0.5 ,
98
107
transformations : Callable = [],
108
+ padding_factor : float = 1.5 ,
99
109
** kwargs
100
110
):
101
111
Calculator .__init__ (self , ** kwargs )
@@ -105,6 +115,7 @@ def __init__(
105
115
self .n_models = 1 if isinstance (model_dir , (Path , str )) else len (model_dir )
106
116
107
117
self .model_config , self .params = restore_parameters (model_dir )
118
+ self .padding_factor = padding_factor
108
119
109
120
if self .model_config .model .calc_stress :
110
121
self .implemented_properties .append ("stress" )
@@ -119,13 +130,18 @@ def __init__(
119
130
self .step = None
120
131
self .neighbor_fn = None
121
132
self .neighbors = None
133
+ self .offsets = None
122
134
123
135
def initialize (self , atoms ):
136
+ box = jnp .asarray (atoms .cell .array , dtype = jnp .float64 )
137
+ self .r_max = self .model_config .model .r_max
138
+ self .neigbor_from_jax = neighbor_calculable_with_jax (box , self .r_max )
124
139
model , neighbor_fn = build_energy_neighbor_fns (
125
140
atoms ,
126
141
self .model_config ,
127
142
self .params ,
128
143
self .dr_threshold ,
144
+ self .neigbor_from_jax ,
129
145
)
130
146
131
147
if self .is_ensemble :
@@ -134,7 +150,99 @@ def initialize(self, atoms):
134
150
for transformation in self .transformations :
135
151
model = transformation .apply (model , self .n_models )
136
152
137
- Z = jnp .asarray (atoms .numbers )
153
+ self .step = get_step_fn (model , atoms , self .neigbor_from_jax )
154
+ self .neighbor_fn = neighbor_fn
155
+
156
+ def set_neighbours_and_offsets (self , atoms , box ):
157
+ idxs_i , idxs_j , offsets = neighbour_list ("ijS" , atoms , self .r_max )
158
+
159
+ if len (idxs_i ) > self .padded_length :
160
+ print ("neighbor list overflowed, reallocating." )
161
+ self .padded_length = int (len (idxs_i ) * self .padding_factor )
162
+ self .initialize (atoms )
163
+
164
+ zeros_to_add = self .padded_length - len (idxs_i )
165
+
166
+ self .neighbors = np .array ([idxs_i , idxs_j ], dtype = np .int32 )
167
+ self .neighbors = np .pad (self .neighbors , ((0 , 0 ), (0 , zeros_to_add )), "constant" )
168
+
169
+ offsets = np .matmul (offsets , box )
170
+ self .offsets = np .pad (offsets , ((0 , zeros_to_add ), (0 , 0 )), "constant" )
171
+
172
+ def calculate (self , atoms , properties = ["energy" ], system_changes = all_changes ):
173
+ Calculator .calculate (self , atoms , properties , system_changes )
174
+ positions = jnp .asarray (atoms .positions , dtype = jnp .float64 )
175
+ box = jnp .asarray (atoms .cell .array , dtype = jnp .float64 )
176
+
177
+ # setup model and neighbours
178
+ if self .step is None :
179
+ self .initialize (atoms )
180
+
181
+ if self .neigbor_from_jax :
182
+ self .neighbors = self .neighbor_fn .allocate (positions )
183
+ else :
184
+ idxs_i = neighbour_list ("i" , atoms , self .r_max )
185
+ self .padded_length = int (len (idxs_i ) * self .padding_factor )
186
+
187
+ elif "numbers" in system_changes :
188
+ self .initialize (atoms )
189
+
190
+ if self .neigbor_from_jax :
191
+ self .neighbors = self .neighbor_fn .allocate (positions )
192
+
193
+ elif "cell" in system_changes :
194
+ neigbor_from_jax = neighbor_calculable_with_jax (box , self .r_max )
195
+ if self .neigbor_from_jax != neigbor_from_jax :
196
+ self .initialize (atoms )
197
+
198
+ # predict
199
+ if self .neigbor_from_jax :
200
+ results , self .neighbors = self .step (positions , self .neighbors , box )
201
+
202
+ if self .neighbors .did_buffer_overflow :
203
+ print ("neighbor list overflowed, reallocating." )
204
+ self .initialize (atoms )
205
+ self .neighbors = self .neighbor_fn .allocate (positions )
206
+
207
+ results , self .neighbors = self .step (positions , self .neighbors , box )
208
+
209
+ else :
210
+ self .set_neighbours_and_offsets (atoms , box )
211
+ positions = np .array (space .transform (np .linalg .inv (box ), atoms .positions ))
212
+
213
+ results = self .step (positions , self .neighbors , box , self .offsets )
214
+
215
+ self .results = {k : np .array (v , dtype = np .float64 ) for k , v in results .items ()}
216
+ self .results ["energy" ] = self .results ["energy" ].item ()
217
+
218
+
219
+ def neighbor_calculable_with_jax (box , r_max ):
220
+ if np .all (box < 1e-6 ):
221
+ return True
222
+ else :
223
+ # all lettice vector combinations to calculate all three plane distances
224
+ a_vec_list = [box [0 ], box [0 ], box [1 ]]
225
+ b_vec_list = [box [1 ], box [2 ], box [2 ]]
226
+ c_vec_list = [box [2 ], box [1 ], box [0 ]]
227
+
228
+ height = []
229
+ for i in range (3 ):
230
+ normvec = np .cross (a_vec_list [i ], b_vec_list [i ])
231
+ projection = (
232
+ c_vec_list [i ]
233
+ - np .sum (normvec * c_vec_list [i ]) / np .sum (normvec ** 2 ) * normvec
234
+ )
235
+ height .append (np .linalg .norm (c_vec_list [i ] - projection ))
236
+
237
+ if np .min (height ) / 2 > r_max :
238
+ return True
239
+ else :
240
+ return False
241
+
242
+
243
+ def get_step_fn (model , atoms , neigbor_from_jax ):
244
+ Z = jnp .asarray (atoms .numbers )
245
+ if neigbor_from_jax :
138
246
139
247
@jax .jit
140
248
def step_fn (positions , neighbor , box ):
@@ -145,33 +253,24 @@ def step_fn(positions, neighbor, box):
145
253
neighbor = neighbor .update (positions , box = box )
146
254
else :
147
255
neighbor = neighbor .update (positions )
148
- offsets = jnp .full ([neighbor .idx .shape [1 ], 3 ], 0 )
149
256
257
+ offsets = jnp .full ([neighbor .idx .shape [1 ], 3 ], 0 )
150
258
results = model (positions , Z , neighbor .idx , box , offsets )
151
259
152
260
if "stress" in results .keys ():
153
261
results = process_stress (results , box )
154
262
155
263
return results , neighbor
156
264
157
- self .step = step_fn
158
- self .neighbor_fn = neighbor_fn
265
+ else :
159
266
160
- def calculate (self , atoms , properties = ["energy" ], system_changes = all_changes ):
161
- Calculator .calculate (self , atoms , properties , system_changes )
162
- positions = jnp .asarray (atoms .positions , dtype = jnp .float64 )
163
- box = jnp .asarray (atoms .cell .array , dtype = jnp .float64 )
164
- if self .step is None or "numbers" in system_changes :
165
- self .initialize (atoms )
166
- self .neighbors = self .neighbor_fn .allocate (positions )
267
+ @jax .jit
268
+ def step_fn (positions , neighbor , box , offsets ):
269
+ results = model (positions , Z , neighbor , box , offsets )
167
270
168
- results , self .neighbors = self .step (positions , self .neighbors , box )
271
+ if "stress" in results .keys ():
272
+ results = process_stress (results , box )
169
273
170
- if self .neighbors .did_buffer_overflow :
171
- print ("neighbor list overflowed, reallocating." )
172
- self .initialize (atoms )
173
- self .neighbors = self .neighbor_fn .allocate (positions )
174
- results , self .neighbors = self .step (positions , self .neighbors , box )
274
+ return results
175
275
176
- self .results = {k : np .array (v , dtype = np .float64 ) for k , v in results .items ()}
177
- self .results ["energy" ] = self .results ["energy" ].item ()
276
+ return step_fn
0 commit comments