@@ -171,23 +171,32 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
171
171
atom_ener_coeff = tf .reshape (atom_ener_coeff , tf .shape (atom_ener ))
172
172
energy = tf .reduce_sum (atom_ener_coeff * atom_ener , 1 )
173
173
if self .has_e :
174
- # Yufan: put extra weight on the last 7 energy components
175
- ener_diff = energy - energy_hat # TODO Yufan multiple the weight
176
- ener_diff_reshape = tf .reshape (ener_diff , [- 1 ], name = 'ener_diff_reshape' )
177
- ener_diff_reshape = tf .concat ([ener_diff_reshape [:- 7 ], ener_diff_reshape [- 7 :] * 10 ], axis = 0 ) # scale the last 7 force by 10
178
- l2_ener_loss = tf .reduce_mean ( tf .square (ener_diff ), name = 'l2_' + suffix ) # energy loss
174
+ # # Yufan: put extra weight on the last 7 energy components
175
+ # ener_diff = energy - energy_hat # TODO Yufan multiple the weight
176
+ # ener_diff_log = tf.convert_to_tensor(ener_diff, name='ener_diff')
177
+ # ener_diff_reshape = tf.reshape(ener_diff, [-1], name='ener_diff_reshape')
178
+ # tf.print("Ener_diff_reshape: ", ener_diff_reshape)
179
+ # tf.summary.histogram("ener_diff_reshape", ener_diff_reshape)
180
+
181
+ # ener_diff_reshape_shape = tf.strings.format("{}", tf.shape(ener_diff_reshape))
182
+ # tf.summary.text("ener_diff_reshape_shape", ener_diff_reshape_shape)
183
+ # ener_diff_log_shape = tf.strings.format("{}", tf.shape(ener_diff_log))
184
+ # tf.summary.text("ener_diff_log_shape", ener_diff_log_shape)
185
+
186
+ # ener_diff_reshape = tf.concat([ener_diff_reshape[:-7], ener_diff_reshape[-7:] * 10], axis=0) # scale the last 7 force by 10
187
+ # l2_ener_loss = tf.reduce_mean( tf.square(ener_diff), name='l2_'+suffix) # energy loss
179
188
180
- print ("*************************************************************************************************" )
181
- print ("******************** Warning: the last 7 energy components are scaled by 10 ***************" )
182
- print ("*************************************************************************************************" )
189
+ # print("*************************************************************************************************")
190
+ # print("******************** Warning: the last 7 energy components are scaled by 10 ***************")
191
+ # print("*************************************************************************************************")
183
192
184
- # print shape of energy and energy_hat
185
- print (f"energy shape: { energy .shape } " )
186
- print (f"energy_hat shape: { energy_hat .shape } " )
193
+ # # print shape of energy and energy_hat
194
+ # print(f"energy shape: {energy.shape}")
195
+ # print(f"energy_hat shape: {energy_hat.shape}")
187
196
188
- # l2_ener_loss = tf.reduce_mean(
189
- # tf.square(energy - energy_hat), name="l2_" + suffix
190
- # )
197
+ l2_ener_loss = tf .reduce_mean (
198
+ tf .square (energy - energy_hat ), name = "l2_" + suffix
199
+ )
191
200
192
201
if self .has_f or self .has_pf or self .relative_f or self .has_gf :
193
202
force_reshape = tf .reshape (force , [- 1 ])
@@ -204,6 +213,16 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
204
213
if self .has_f :
205
214
# Yufan: put extra weight on the last 7 force components
206
215
diff_f_reshape = tf .reshape (diff_f , [- 1 ], name = "diff_f_reshape" )
216
+ diff_f_log = tf .convert_to_tensor (diff_f , name = 'diff_f' )
217
+ tf .print ("Diff_f_reshape: " , diff_f_reshape )
218
+ tf .summary .histogram ("diff_f_reshape" , diff_f_reshape )
219
+
220
+ diff_f_reshape_shape = tf .strings .format ("{}" , tf .shape (diff_f_reshape ))
221
+ tf .summary .text ("diff_f_reshape_shape" , diff_f_reshape_shape )
222
+ diff_f_log_shape = tf .strings .format ("{}" , tf .shape (diff_f_log ))
223
+ tf .summary .text ("diff_f_log_shape" , diff_f_log_shape )
224
+
225
+
207
226
diff_f_reshape = tf .concat ([diff_f_reshape [:- 7 * 3 ], diff_f_reshape [- 7 * 3 :] * 10 ], axis = 0 ) # scale the last 7 force by 10
208
227
l2_force_loss = tf .reduce_mean (tf .square (diff_f_reshape ), name = "l2_force_" + suffix ) # force loss
209
228
0 commit comments