-
Notifications
You must be signed in to change notification settings - Fork 9
/
mutation_analysis.py
56 lines (49 loc) · 2.01 KB
/
mutation_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import fasta, preprocessing, model, evaluate
import sys, os, argparse, re
'''Usage: mutation_analysis.py <fasta> <model weights file> <output file name>'''
'''Options:
-o Overwrite output file (default: error if file already exists)
-e Extract Ensembl transcript ID from fasta header (default: use full header)'''
# The code below treats sequence elements as integer base codes; the index
# into this string is the code: 0 -> 'N', 1 -> 'A', 2 -> 'T', 3 -> 'C', 4 -> 'G'.
_BASES = 'NATCG'

# Ensembl transcript ID with an optional ".<version>" suffix. The dot is
# escaped so that it only matches a literal version separator.
_TRANSCRIPT_ID = re.compile(r'ENST\d+(?:\.\d+)?')


def _point_mutants(seq, name):
    """Return every single-base substitution of one encoded sequence.

    Args:
        seq: List of integer base codes (see ``_BASES``). It is not modified.
        name: Identifier copied unchanged into every returned record.

    Returns:
        A list of ``(mutated_seq, position, base, name)`` tuples.
        ``position`` is the 0-based index of the substitution as a string and
        ``base`` is the letter that was substituted in. Positions are visited
        left to right; at each one, codes 1-4 (A, T, C, G) that differ from
        the current code are tried in order. 'N' (code 0) is never
        substituted in, so an 'N' position produces four mutants.
    """
    mutants = []
    for i, current in enumerate(seq):
        for code in range(1, len(_BASES)):
            if code != current:
                mutated = seq[:i] + [code] + seq[i + 1:]
                mutants.append((mutated, str(i), _BASES[code], name))
    return mutants


def _transcript_id(header):
    """Return the Ensembl transcript ID contained in a fasta header.

    Falls back to the full header when it contains no ``ENST`` ID, which is
    what the script writes when ``-e`` is not given.
    """
    match = _TRANSCRIPT_ID.search(header)
    return match.group() if match else header


def main():
    """Score every single-base substitution of each input sequence.

    Parses the command line, builds all point mutants of the sequences in the
    fasta file, scores them with the model built from the weights file, and
    writes one tab-separated line per mutant to the output file:
    name, 0-based position, substituted base, score.

    Raises:
        Exception: If the output file already exists and ``-o`` was not given.
    """
    # Options
    parser = argparse.ArgumentParser()
    parser.add_argument('fasta', help='Fasta file of sequences to mutate.')
    parser.add_argument('weights',
                        help='File containing model weights. This specifies '
                             'which model to use.')
    parser.add_argument('output',
                        help='File where results will be written. By default, '
                             'this script will not run if the output file '
                             'already exists. Use the -o option to overwrite '
                             'an existing file.')
    parser.add_argument('-o', action='store_true',
                        help='Use this option if you want to overwrite an '
                             'existing output file.')
    parser.add_argument('-e', action='store_true',
                        help='Use this option to write just the Ensembl '
                             'transcript ID instead of the full header.')
    args = parser.parse_args()

    # Refuse to clobber earlier results before doing any expensive work.
    if not args.o and os.path.exists(args.output):
        raise Exception(args.output + ' already exists!')

    print("Reading input files...")
    # NOTE(review): the meaning of the second argument (0) is defined by
    # fasta.load_fasta, which is not visible here. The loop below relies on
    # it yielding (list_of_base_codes, header) pairs.
    full_seqs = fasta.load_fasta(args.fasta, 0)
    mutants = []
    for seq, name in full_seqs:
        mutants.extend(_point_mutants(seq, name))

    if not mutants:
        # zip(*[]) has nothing to unpack into the four columns below, so an
        # empty input is handled explicitly instead of failing mid-run.
        print("No sequences to mutate; writing an empty output file.")
        with open(args.output, 'w'):
            pass
        return

    mRNN = model.build_model(args.weights)
    print("Evaluating sequences...")
    # Transpose the records into parallel columns.
    seqs, positions, bases, names = zip(*mutants)
    if args.e:
        names = [_transcript_id(n) for n in names]
    scores = [str(score) for score in mRNN.batch_predict(seqs, True)]
    lines = ['\t'.join(fields)
             for fields in zip(names, positions, bases, scores)]
    with open(args.output, 'w') as out:
        out.write('\n'.join(lines))
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()