forked from TrustAI/testRNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·84 lines (73 loc) · 3.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import time
import sys
sys.path.append('example')
sys.path.append('src')
from utils import mkdir, delete_folder
from sentimentTestSuite import sentimentTrainModel, sentimentGenerateTestSuite
from mnistTestSuite_adv_test import mnist_lstm_train, mnist_lstm_adv_test
from mnistTestSuite_backdoor_test import mnist_lstm_backdoor_test
from lipoTestSuite import lipo_lstm_train, lipo_lstm_test
from ucf101_vgg16_lstm_TestSuite import vgg16_lstm_train, vgg16_lstm_test
from record import record
import re
def main(argv=None):
    """Parse command-line options and train or test one recurrent model.

    The ``--model`` option selects which project entry point runs:

    * ``--mode train`` calls the model's training function.
    * ``--mode backdoor`` (``mnist`` only) runs the backdoor test suite.
    * any other mode (``test``) runs the model's test-suite generator with
      the coverage thresholds, mutation method and stopping criteria below.

    All numeric options are forwarded to the project functions as the raw
    strings argparse produced (no ``type=`` is given); presumably the callees
    convert them — verify against those functions before changing this.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            command-line behaviour.
    """
    parser = argparse.ArgumentParser(description='testing for recurrent neural networks')
    parser.add_argument('--model', dest='modelName', choices=['mnist', 'sentiment', 'lipo', 'ucf101'], default='sentiment')
    parser.add_argument('--TestCaseNum', dest='TestCaseNum', default='10000')
    parser.add_argument('--Mutation', dest='Mutation', choices=['random', 'genetic'], default='random')
    parser.add_argument('--CoverageStop', dest='CoverageStop', default='0.9')
    parser.add_argument('--threshold_SC', dest='threshold_SC', default='0.6')
    parser.add_argument('--threshold_BC', dest='threshold_BC', default='0.8')
    parser.add_argument('--symbols_TC', dest='symbols_TC', default='3')
    parser.add_argument('--seq', dest='seq', default='[70,89]')
    # 'backdoor' must be a legal choice; otherwise argparse rejects it and the
    # mnist backdoor branch below is unreachable.
    parser.add_argument('--mode', dest='mode', choices=['train', 'test', 'backdoor'], default='test',
                        help="'backdoor' is only supported for --model mnist")
    parser.add_argument('--output', dest='filename', default='./log_folder/record.txt', help='')
    args = parser.parse_args(argv)

    # Only mnist has a backdoor test; for other models 'backdoor' would
    # silently fall through to the ordinary test branch, so reject it.
    if args.mode == 'backdoor' and args.modelName != 'mnist':
        parser.error("--mode backdoor is only supported for --model mnist")

    # Suggested --seq ranges per model:
    #   mnist [4,24]
    #   sentiment [400,499]
    #   lipo [60,79]
    #   ucf101 [0,10]
    modelName = args.modelName
    mode = args.mode
    filename = args.filename
    threshold_SC = args.threshold_SC
    threshold_BC = args.threshold_BC
    symbols_TC = args.symbols_TC
    # Pull the numbers out of a string such as '[70,89]' -> ['70', '89']
    # (the items stay strings).
    seq = re.findall(r"\d+\.?\d*", args.seq)
    Mutation = args.Mutation
    CoverageStop = args.CoverageStop
    TestCaseNum = args.TestCaseNum

    # Open the record with the start time; always close it, even on failure.
    r = record(filename, time.time())
    try:
        if modelName == 'sentiment':
            if mode == 'train':
                sentimentTrainModel()
            else:
                sentimentGenerateTestSuite(r, threshold_SC, threshold_BC, symbols_TC, seq, TestCaseNum, Mutation, CoverageStop)
        elif modelName == 'mnist':
            if mode == 'train':
                mnist_lstm_train()
            elif mode == 'backdoor':
                mnist_lstm_backdoor_test(r, threshold_SC, threshold_BC, symbols_TC, seq, TestCaseNum, Mutation, CoverageStop)
            else:
                mnist_lstm_adv_test(r, threshold_SC, threshold_BC, symbols_TC, seq, TestCaseNum, Mutation, CoverageStop)
        elif modelName == 'lipo':
            if mode == 'train':
                lipo_lstm_train()
            else:
                lipo_lstm_test(r, threshold_SC, threshold_BC, symbols_TC, seq, TestCaseNum, Mutation, CoverageStop)
        elif modelName == 'ucf101':
            if mode == 'train':
                vgg16_lstm_train()
            else:
                vgg16_lstm_test(r, threshold_SC, threshold_BC, symbols_TC, seq, TestCaseNum, Mutation, CoverageStop)
        else:
            # Defensive: argparse's `choices` already restricts --model.
            print("Please specify a model from {sentiment, mnist, lipo, ucf101}")
    finally:
        r.close()
if __name__ == "__main__":
    # Time the whole run (argument parsing included) and report it at the end.
    launched_at = time.time()
    main()
    elapsed = time.time() - launched_at
    print("--- %s seconds ---" % elapsed)