-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathschnet.py
134 lines (101 loc) · 3.65 KB
/
schnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import schnetpack as spk
from schnetpack.datasets import QM9
# Working directory for everything this script writes: the split file, the
# training log and the model checkpoints all use this path.
qm9tut = './qm9tut'
# Create the directory through the same variable that the rest of the script
# uses (the original tested the separate literal 'qm9tut'). exist_ok=True
# makes the call a no-op when the directory already exists, so no separate
# existence check is needed.
os.makedirs(qm9tut, exist_ok=True)
# Open the local QM9 database (already pre-downloaded), loading only the U0
# property and passing remove_uncharacterized=True.
qm9data = QM9(
    './qm9.db',
    load_only=[QM9.U0],
    remove_uncharacterized=True,
)

# Split into 100 training and 10 validation examples; the third returned
# subset is used as the test set at the end of the script. The split is
# tied to a file inside the working directory.
split_path = os.path.join(qm9tut, "split.npz")
train, val, test = spk.train_test_split(
    data=qm9data, num_train=100, num_val=10, split_file=split_path
)
print(len(test))
# Batched loaders: shuffled batches of 50 for training, single-example
# batches for validation.
train_loader = spk.AtomsLoader(
    train,
    batch_size=50,
    shuffle=True,
)
val_loader = spk.AtomsLoader(val, batch_size=1)

# Sanity check: print the reference U0 values of every training batch.
for batch in train_loader:
    u0_targets = batch[QM9.U0]
    print(u0_targets)
# Single-atom reference values for U0, as provided by the dataset object.
atomrefs = qm9data.get_atomref(QM9.U0)

# Print the reference value of a few elements. The table is indexed with the
# element's atomic number (1 = H, 6 = C, 8 = O). One loop replaces three
# copy-pasted print calls and fixes the misspelled "hyrogen" label.
for element, atomic_number in (('hydrogen', 1), ('carbon', 6), ('oxygen', 8)):
    print(
        'U0 of {}:'.format(element),
        '{:.2f}'.format(atomrefs[QM9.U0][atomic_number][0]),
        'eV',
    )

# check energy: print the stored "energy_U0" property of the first entry
# (presumably the same key as QM9.U0 -- confirm against the dataset class).
print(qm9data.get_properties(0)[1]["energy_U0"])

# Per-atom statistics of U0 over the training loader, computed with the
# single-atom references passed in; used below to configure the output module.
means, stddevs = train_loader.get_statistics(
    QM9.U0, divide_by_atoms=True, single_atom_ref=atomrefs
)
print('Mean atomization energy / atom:', means[QM9.U0])
print('Std. dev. atomization energy / atom:', stddevs[QM9.U0])
# Width of the per-atom feature vector: produced by the representation and
# consumed by the output module, so both are tied to one value.
n_features = 30

# SchNet representation network.
schnet = spk.representation.SchNet(
    n_atom_basis=n_features,
    n_filters=30,
    n_gaussians=20,
    n_interactions=5,
    cutoff=4.,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
)

# Atom-wise output module for U0, configured with the single-atom references
# and the training-set mean / standard deviation computed earlier.
output_U0 = spk.atomistic.Atomwise(
    n_in=n_features,
    atomref=atomrefs[QM9.U0],
    property=QM9.U0,
    mean=means[QM9.U0],
    stddev=stddevs[QM9.U0],
)

# Full model: representation followed by the U0 output module.
model = spk.AtomisticModel(representation=schnet, output_modules=output_U0)
# optimization:
from torch.optim import Adam
# loss function
def mse_loss(batch, result):
    """Return the mean squared error between reference and predicted U0.

    Args:
        batch: mapping that holds the reference values under ``QM9.U0``.
        result: mapping that holds the model predictions under ``QM9.U0``.

    Returns:
        The value of ``torch.mean`` over the squared element-wise
        differences.

    Note:
        Nothing in this script calls this function; the trainer is given
        the loss built by ``trn.build_mse_loss`` instead.
    """
    # Imported locally because the module only imports torch further down,
    # after training has run; relying on the module-level name would raise
    # NameError if this function were called before that point.
    import torch

    diff = batch[QM9.U0] - result[QM9.U0]
    err_sq = torch.mean(diff ** 2)
    return err_sq
# Optimizer over all model parameters.
optimizer = Adam(model.parameters(), lr=1e-2)

import schnetpack.train as trn

# Training loss on U0, plus the mean absolute error as a logged metric.
loss = trn.build_mse_loss([QM9.U0])
metrics = [spk.metrics.MeanAbsoluteError(QM9.U0)]

# Hooks: CSV logging into the working directory, and a reduce-on-plateau
# learning-rate schedule configured with stop_after_min=True.
csv_hook = trn.CSVHook(log_path=qm9tut, metrics=metrics)
lr_hook = trn.ReduceLROnPlateauHook(
    optimizer,
    patience=5,
    factor=0.8,
    min_lr=1e-6,
    stop_after_min=True,
)
hooks = [csv_hook, lr_hook]

# Trainer tying together model, loss, optimizer, hooks and data loaders;
# it is given the working directory as its model path.
trainer = trn.Trainer(
    model_path=qm9tut,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)
# Device handed to the trainer and reused by the evaluation loop further down.
# Currently the CPU; set to "cuda" to train on a GPU if one is available.
device = "cpu"
# Number of epochs requested from the trainer. Original note: about 10 min on
# a notebook GPU (not verified for CPU); reduce for quick experiments.
n_epochs = 200
# Run the training with the settings above.
trainer.train(device=device, n_epochs=n_epochs)
# training log: load it and report the final validation MAE (no plot is drawn):
import numpy as np
import matplotlib.pyplot as plt
from ase.units import kcal, mol
# Read the CSV training log from the working directory, skipping its header
# row. ndmin=2 keeps the array two-dimensional even when the log holds a
# single data row, so the column slicing below does not raise IndexError.
results = np.loadtxt(
    os.path.join(qm9tut, 'log.csv'), skiprows=1, delimiter=',', ndmin=2
)

# Columns, in the order this script reads them (assumed to match what the
# CSV hook writes -- confirm against the log header).
time = results[:, 0] - results[0, 0]  # relative to the first logged row
learning_rate = results[:, 1]
train_loss = results[:, 2]
val_loss = results[:, 3]
val_mae = results[:, 4]

# Report the last logged validation MAE in eV and converted to kcal/mol.
print('Final validation MAE:', np.round(val_mae[-1], 2), 'eV =',
      np.round(val_mae[-1] / (kcal/mol), 2), 'kcal/mol')
# test model:
import torch
# Load the best checkpoint from the working directory.
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that this script produced itself.
# map_location places the stored tensors on the selected device regardless of
# the device they were saved from; without it, a checkpoint written on a GPU
# fails to load (or lands on the wrong device) when evaluating on the CPU.
best_model = torch.load(
    os.path.join(qm9tut, 'best_model'), map_location=device
)
test_loader = spk.AtomsLoader(test, batch_size=100)

err = 0
n_batches = len(test_loader)  # loop-invariant, computed once
print(n_batches)
for count, batch in enumerate(test_loader):
    # move batch to the evaluation device
    batch = {k: v.to(device) for k, v in batch.items()}

    # apply model
    pred = best_model(batch)

    # calculate absolute error, summed over the batch
    tmp = torch.sum(torch.abs(pred[QM9.U0] - batch[QM9.U0]))
    tmp = tmp.detach().cpu().numpy()  # detach from graph & convert to numpy
    err += tmp

    # log progress on a single, continuously overwritten line
    percent = '{:3.2f}'.format(count / n_batches * 100)
    print('Progress:', percent + '%' + ' ' * (5 - len(percent)), end="\r")

# Turn the accumulated absolute error into a mean over all test examples.
err /= len(test)
print('Test MAE', np.round(err, 2), 'eV =',
      np.round(err / (kcal/mol), 2), 'kcal/mol')