mRNN_ensemble.py
import fasta, preprocessing, model, evaluate
import sys, os, getopt
#########
# USAGE #
#########
'''
Prints the usage statement and all options
'''
def usage():
    script = os.path.basename(__file__)
    print "\n\nUsage: " + script + " [options] <fasta file>"
    print('''
    Options:
    -h --help\t\tprints this help message.
    -o --output\t\tthe file base for the output files.
    -w --weights\tcomma-separated list of pkl files of the model/model weights.
    -E --epochs\tnumber of epochs to train on (default=50).
    -b --batch_size\tbatch size for testing (default=16).
    -e --embedding_size\tnumber of dimensions in the embedding (default=128).
    -r --recurrent_gate_size\tsize of the recurrent gate (default=256).
    -d --dropout\tthe dropout probability p_dropout (default=0.1).
    -t --test\tproportion of data to test on (default=0.1).
    -l --min_length\tminimum length sequence to train on (default=200).
    -L --max_length\tmaximum length sequence to train on (default=1000).
    -f --file_label\ta text label on the accuracy output files.
    ''')
    sys.exit()
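# Example invocation (illustrative only; the weight files and FASTA name below
# are hypothetical placeholders, not files shipped with this repository):
#
#   python mRNN_ensemble.py -w weights1.pkl,weights2.pkl,weights3.pkl \
#       -o my_output -b 32 transcripts.fa
#
# This would score every sequence in transcripts.fa with each of the listed
# models and write the ensemble results to the file base given by -o.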
#########
# MAIN #
#########
'''
The main function. Parse input options, build the ensemble of models from the
given weight files, and evaluate the input sequences.
'''
def main():
    # Options
    opts, files = getopt.getopt(sys.argv[1:], "hvo:w:E:b:e:r:d:t:l:L:n:f:", ["help",
                                                                             "output=",
                                                                             "weights=",
                                                                             "epochs=",
                                                                             "batch_size=",
                                                                             "embedding_size=",
                                                                             "recurrent_gate_size=",
                                                                             "dropout=",
                                                                             "test=",
                                                                             "min_length=",
                                                                             "max_length=",
                                                                             "num_train=",
                                                                             "file_label=",
                                                                             ])
    if len(files) != 1:
        usage()
    fastaFile = files[0]
    print "using fasta file: ", fastaFile
    # Defaults:
    parameters = {}
    parameters['output'] = None
    parameters['verbose'] = False
    parameters['weights'] = None
    parameters['batch_size'] = 16
    parameters['embedding_size'] = 128
    parameters['recurrent_gate_size'] = 256
    parameters['dropout'] = 0.1
    parameters['test'] = 0.1
    parameters['min_length'] = 200
    parameters['max_length'] = 1000
    parameters['num_train'] = 10000
    parameters['epochs'] = 50
    parameters['save_freq'] = 3
    parameters['file_label'] = ""
    # loop over options:
    for option, argument in opts:
        if option == "-v":
            parameters['verbose'] = True
        elif option in ("-h", "--help"):
            usage()
        elif option in ("-o", "--output"):
            parameters['output'] = argument
        elif option in ("-w", "--weights"):
            parameters['weights'] = argument
        elif option in ("-E", "--epochs"):
            parameters['epochs'] = int(argument)
        elif option in ("-b", "--batch_size"):
            parameters['batch_size'] = int(argument)
        elif option in ("-e", "--embedding_size"):
            parameters['embedding_size'] = int(argument)
        elif option in ("-r", "--recurrent_gate_size"):
            parameters['recurrent_gate_size'] = int(argument)
        elif option in ("-d", "--dropout"):
            parameters['dropout'] = float(argument)
        elif option in ("-t", "--test"):
            parameters['test'] = float(argument)
        elif option in ("-l", "--min_length"):
            parameters['min_length'] = int(argument)
        elif option in ("-L", "--max_length"):
            parameters['max_length'] = int(argument)
        elif option in ("-n", "--num_train"):
            parameters['num_train'] = int(argument)
        elif option in ("-f", "--file_label"):
            parameters['file_label'] = argument
        else:
            assert False, "unhandled option"
    ##########
    ## MAIN ##
    ##########
    print "Reading input files..."
    sequences = fasta.load_fasta(fastaFile, parameters['min_length'])
    if not parameters['weights']:
        print "No weights given with -w parameter.\n"
        sys.exit()
    modelFiles = parameters['weights'].split(',')
    models = []
    # Build one mRNN model per weights file to form the ensemble.
    for modelFile in modelFiles:
        print "Building model..."
        mRNN = model.build_model(modelFile, parameters['embedding_size'], parameters['recurrent_gate_size'], 5, parameters['dropout'])
        models.append(mRNN)
    print "Evaluating sequences..."
    output = fastaFile + ".mRNNensemble"
    if parameters['output']:
        output = parameters['output']
    evaluate.ensemble_evaluate_sequences(models, sequences, output, parameters['batch_size'])
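# --- Illustrative sketch (not part of the original mRNN code) ---------------
# The real ensemble logic lives in evaluate.ensemble_evaluate_sequences; the
# helper below is only a hypothetical example of one common way per-model
# scores could be combined (a simple mean across models). It is never called
# by this script.
def _average_ensemble_scores(per_model_scores):
    """Element-wise mean of per-sequence scores across models.

    per_model_scores: list of equal-length score lists, one per model.
    Returns a list with one averaged score per sequence.
    """
    num_models = float(len(per_model_scores))
    return [sum(scores) / num_models for scores in zip(*per_model_scores)]
# ----------------------------------------------------------------------------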
if __name__ == "__main__":
    main()