-
Notifications
You must be signed in to change notification settings - Fork 9
/
train_mRNN.py
138 lines (128 loc) · 5.49 KB
/
train_mRNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import fasta, preprocessing, model, evaluate
import sys, os, getopt
import inspect
#########
# USAGE #
#########
'''
Prints the usage statement and all options
'''
def usage():
    """Print the usage statement with every documented option, then exit.

    The script name shown is the basename of this file. The function never
    returns: it always ends by calling ``sys.exit()`` with no argument.

    Raises:
        SystemExit: unconditionally, after the help text has been printed.
    """
    script = os.path.basename(__file__)
    # Single-argument print(...) produces the same output whether it is
    # parsed as the Python 2 print statement or the Python 3 print function.
    print("\n\nUsage: " + script + " [options] <positive fasta> <negative fasta> <positive validation fasta> <negative validation fasta>")
    # The epochs and batch_size defaults quoted here match the values that
    # main() actually starts from (25 and 16).
    print('''
Options:
-h --help\t\tprints this help message.
-o --output\t\tthe file-base for the output files.
-w --weights\tpkl file of the model/model weights.
-E --epochs\tNumber of epochs to train on.(default=25)
-b --batch_size\tbatch size for testing (default=16)
-e --embedding_size\tNumber of dimensions in embedding (default=128)
-r --recurrent_gate_size\tSize of recurrent gate (default=256)
-d --dropout\tThe dropout probability p_dropout (default=0.5)
-t --test\tProportion of data to test on. (default=0.1)
-l --min_length\tminimum length sequence to train on (default=200)
-L --max_length\tmaximum length sequence to train on (default=1000)
-s --early_stopping\tNumber of epochs above minimum validation score before stopping
''')
    sys.exit()
#########
# MAIN #
#########
'''
The main loop. Parse input options, run training sequence.
'''
def main():
    """Parse the command line, load the FASTA inputs and train the model.

    Reads ``sys.argv[1:]``. Exactly four positional arguments are required,
    in this order: positive training FASTA, negative training FASTA,
    positive validation FASTA, negative validation FASTA. Any other count
    prints the usage text and exits via ``usage()``.

    Returns:
        Whatever ``model.train_model`` returns for the trained model.

    Raises:
        getopt.GetoptError: for an unrecognised option or a missing option
            argument (raised by ``getopt.getopt``).
        ValueError: when a numeric option value cannot be converted by
            ``int`` / ``float``.
        SystemExit: via ``usage()`` on ``-h``/``--help`` or a wrong number
            of positional arguments.
    """
    # "-n"/"--num_train" is declared here so that the option handled in the
    # loop below is actually reachable; previously getopt rejected it.
    short_opts = "hvo:w:E:b:e:r:d:t:l:L:n:s:"
    long_opts = [
        "help",
        "output=",
        "weights=",
        "epochs=",
        "batch_size=",
        "embedding_size=",
        "recurrent_gate_size=",
        "dropout=",
        "test=",
        "min_length=",
        "max_length=",
        "num_train=",
        "early_stopping=",
    ]
    opts, files = getopt.getopt(sys.argv[1:], short_opts, long_opts)

    if len(files) != 4:
        usage()
    posFastaFile, negFastaFile, posValFasta, negValFasta = files

    announcements = (
        ("using positive file: ", posFastaFile),
        ("using negative file: ", negFastaFile),
        ("using positive validation file: ", posValFasta),
        ("using negative validation file: ", negValFasta),
    )
    for label, path in announcements:
        # The labels already end in a space; the extra " " reproduces the
        # separator the original two-item print statement inserted, so the
        # output is byte-identical under Python 2 and Python 3.
        print(label + " " + path)

    # Defaults.
    # NOTE: 'verbose', 'batch_size', 'test' and 'num_train' are collected
    # here but are not forwarded to any call made in this function.
    parameters = {
        'output': None,
        'verbose': False,
        'weights': None,
        'batch_size': 16,
        'embedding_size': 128,
        'recurrent_gate_size': 256,
        'dropout': 0.5,
        'test': 0.1,
        'min_length': 200,
        'max_length': 1000,
        'num_train': 10000,
        'epochs': 25,
        'save_freq': 1,
        'early_stopping': None,
    }

    # Every value-taking option: (accepted spellings, parameter key,
    # converter applied to the raw string argument).
    value_options = (
        (("-o", "--output"), 'output', str),
        (("-w", "--weights"), 'weights', str),
        (("-E", "--epochs"), 'epochs', int),
        (("-b", "--batch_size"), 'batch_size', int),
        (("-e", "--embedding_size"), 'embedding_size', int),
        (("-r", "--recurrent_gate_size"), 'recurrent_gate_size', int),
        (("-d", "--dropout"), 'dropout', float),
        (("-t", "--test"), 'test', float),
        (("-l", "--min_length"), 'min_length', int),
        (("-L", "--max_length"), 'max_length', int),
        (("-n", "--num_train"), 'num_train', int),
        (("-s", "--early_stopping"), 'early_stopping', int),
    )
    handlers = {}
    for spellings, key, convert in value_options:
        for spelling in spellings:
            handlers[spelling] = (key, convert)

    # Apply the options in the order they were given, so a later
    # occurrence of an option overrides an earlier one.
    for option, argument in opts:
        if option == "-v":
            # The key must be the string 'verbose'; the bare name used
            # previously raised NameError as soon as -v was passed.
            parameters['verbose'] = True
        elif option in ("-h", "--help"):
            usage()
        elif option in handlers:
            key, convert = handlers[option]
            parameters[key] = convert(argument)
        else:
            # getopt only yields declared options, so this is a guard
            # against the declarations and the table drifting apart.
            # Raised explicitly so it is not stripped under ``python -O``.
            raise AssertionError("unhandled option")

    ###################
    ## LOAD AND TRAIN ##
    ###################
    print("Reading input files...")
    min_length = parameters['min_length']
    positives = fasta.load_fasta(posFastaFile, min_length)
    negatives = fasta.load_fasta(negFastaFile, min_length)
    valpos = fasta.load_fasta(posValFasta, min_length)
    valneg = fasta.load_fasta(negValFasta, min_length)
    train = positives, negatives
    val = valpos, valneg

    print("Building new model...")
    # The literal 5 is passed positionally exactly as before; its meaning
    # is defined by model.build_model, which is not visible from here.
    mRNN = model.build_model(
        parameters['weights'],
        parameters['embedding_size'],
        parameters['recurrent_gate_size'],
        5,
        parameters['dropout'],
    )
    print(inspect.getmodule(mRNN.__class__))

    print("Training model...")
    mRNN = model.train_model(
        mRNN,
        train,
        val,
        parameters['epochs'],
        parameters['output'],
        parameters['max_length'],
        parameters['save_freq'],
        parameters['early_stopping'],
    )
    return mRNN
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()