forked from maxhodak/keras-molecules
train_gen.py
from __future__ import print_function

import argparse
import os

import numpy as np
import pandas as pd

from molecules.vectorizer import SmilesDataGenerator
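
# Default hyperparameters; all but MAX_LEN can be overridden on the command line.
# MAX_LEN is the maximum SMILES string length handled by the vectorizer.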
NUM_EPOCHS = 1
EPOCH_SIZE = 500000
BATCH_SIZE = 500
LATENT_DIM = 292
MAX_LEN = 120
TEST_SPLIT = 0.20
RANDOM_SEED = 1337


def get_arguments():
    parser = argparse.ArgumentParser(description='Molecular autoencoder network')
    parser.add_argument('data', type=str, help='The HDF5 file containing structures.')
    parser.add_argument('model', type=str,
                        help='Where to save the trained model. If this file exists, it will be opened and resumed.')
    parser.add_argument('--epochs', type=int, metavar='N', default=NUM_EPOCHS,
                        help='Number of epochs to run during training.')
    parser.add_argument('--latent_dim', type=int, metavar='N', default=LATENT_DIM,
                        help='Dimensionality of the latent representation.')
    parser.add_argument('--batch_size', type=int, metavar='N', default=BATCH_SIZE,
                        help='Number of samples to process per minibatch during training.')
    parser.add_argument('--epoch_size', type=int, metavar='N', default=EPOCH_SIZE,
                        help='Number of samples to process per epoch during training.')
    parser.add_argument('--test_split', type=float, metavar='N', default=TEST_SPLIT,
                        help='Fraction of the dataset to hold out as test data; the rest is used for training.')
    parser.add_argument('--random_seed', type=int, metavar='N', default=RANDOM_SEED,
                        help='Seed for the random number generator used for shuffling.')
    return parser.parse_args()

def main():
    args = get_arguments()
    np.random.seed(args.random_seed)
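
    # Keras-side imports happen only after the NumPy RNG is seeded above.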
    from molecules.model import MoleculeVAE
    from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
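
    # Expects a pandas HDF5 store with a table under the 'table' key
    # containing a 'structure' column of SMILES strings.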
    data = pd.read_hdf(args.data, 'table')
    structures = data['structure']
    # Alternative: read one SMILES string per line from a gzipped text file.
    # import gzip
    # filepath = args.data
    # structures = [line.split()[0].strip() for line in gzip.open(filepath) if line]

    # A CanonicalSmilesDataGenerator can be used here instead.
    datobj = SmilesDataGenerator(structures, MAX_LEN,
                                 test_split=args.test_split,
                                 random_seed=args.random_seed)
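
    # Ratio of training to test data; with test_split=0.2 this is 4, so the
    # validation pass below sees a quarter as many samples as a training epoch.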
    test_divisor = int((1 - datobj.test_split) / datobj.test_split)
    train_gen = datobj.train_generator(args.batch_size)
    test_gen = datobj.test_generator(args.batch_size)

    # The generators yield a per-sample weight as their third element; the
    # autoencoder reconstructs its input, so re-pair each tensor with itself
    # and drop the weights.
    train_gen = ((tens, tens) for (tens, _, _) in train_gen)
    test_gen = ((tens, tens) for (tens, _, _) in test_gen)
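
    # Resume from an existing checkpoint if one is saved at the target path;
    # otherwise build a fresh model.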
    model = MoleculeVAE()
    if os.path.isfile(args.model):
        model.load(datobj.chars, args.model, latent_rep_size=args.latent_dim)
    else:
        model.create(datobj.chars, latent_rep_size=args.latent_dim)
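
    # Save the model only when validation loss improves (ModelCheckpoint
    # monitors 'val_loss' by default).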
    checkpointer = ModelCheckpoint(filepath=args.model,
                                   verbose=1,
                                   save_best_only=True)
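
    # Multiply the learning rate by 0.2 whenever val_loss fails to improve
    # for 3 consecutive epochs, down to a floor of 1e-4.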
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.2,
                                  patience=3,
                                  min_lr=0.0001)
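
    # Keras 1.x generator API: the positional samples_per_epoch plus nb_epoch,
    # nb_val_samples and pickle_safe were renamed in Keras 2 to steps_per_epoch,
    # epochs, validation_steps and use_multiprocessing (with the step arguments
    # counting batches rather than samples).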
    model.autoencoder.fit_generator(
        train_gen,
        args.epoch_size,
        nb_epoch=args.epochs,
        callbacks=[checkpointer, reduce_lr],
        validation_data=test_gen,
        nb_val_samples=args.epoch_size // test_divisor,
        pickle_safe=True
    )


if __name__ == '__main__':
    main()
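
# Example invocation (file names are illustrative):
#   python train_gen.py data/smiles_500k.h5 model_gen.h5 --epochs 5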