-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdeep_rpo.py
180 lines (149 loc) · 9.53 KB
/
deep_rpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import datetime
import random
import sys
from pathlib import Path
import ast
import seaborn as sns
from networks import *
from training import *
from rpo import *
def deep_rpo_ssl(AD_name, dataset, nbr_modes, nbr_epochs, batchsize, lr_init, lr_decay, lr_milestones,
                 weight_decay, nbr_seeds, train_ratio, nbr_pulse_per_scan, nbr_targets, estimator, nbr_RPs,
                 date_string, supervision="AD", loss="deeprpo", arch="net0", nbr_modes_SAD=1, SAD_ratio=0.01):
    """Run the deep RPO experiment once per seed and collect the results.

    For every seed in ``range(nbr_seeds)`` the function seeds the RNGs, draws
    the normal / outlier / SAD classes from ``[1, 2, 3, 4]``, builds the
    simulated radar dataset, trains the selected network with
    ``training_deepRPO`` and stores weights, figures and metrics under
    ``./results/<dataset>/<AD_name>/<loss>/<date_string>/``.

    Args:
        AD_name: Method label; printed, used in figure titles and in the
            results path.
        dataset: Dataset label; used in the results path and stored with the
            results.
        nbr_modes: Number of classes drawn (without replacement) as normal
            classes.
        nbr_epochs: Number of epochs forwarded to ``training_deepRPO``. The
            AUC plot uses ``nbr_epochs + 1`` x-positions, so the AUC
            sequences are expected to have that length.
        batchsize: Forwarded to ``simulated_radar_dataset``.
        lr_init: Learning rate of the Adam optimizer.
        lr_decay: ``gamma`` of the ``MultiStepLR`` scheduler.
        lr_milestones: ``milestones`` of the ``MultiStepLR`` scheduler.
        weight_decay: Weight decay of the Adam optimizer.
        nbr_seeds: Number of seeds; seeds ``0 .. nbr_seeds - 1`` are run.
        train_ratio: Forwarded to ``simulated_radar_dataset``.
        nbr_pulse_per_scan: Forwarded to ``simulated_radar_dataset``.
        nbr_targets: Forwarded to ``simulated_radar_dataset``.
        estimator: ``estimator`` argument of
            ``Random_Projection_Outlyingness_deep_1D_ssl``.
        nbr_RPs: ``nproj`` argument of
            ``Random_Projection_Outlyingness_deep_1D_ssl``.
        date_string: Last component of the results directory.
        supervision: Which train loader is used; one of ``"AD"``, ``"SSL"``,
            ``"SAD"`` or ``"SAD+SSL"``.
        loss: Passed to ``training_deepRPO`` as ``loss_name``; also part of
            the results path.
        arch: Network selector, ``"net0"`` to ``"net7"``.
        nbr_modes_SAD: Number of outlier classes drawn (without replacement)
            as ``SAD_cls``.
        SAD_ratio: Forwarded to ``simulated_radar_dataset``.

    Returns:
        A ``pandas.DataFrame`` built from the list returned by
        ``store_comparison_results`` (one call per seed). The same frame is
        written to ``experiment_results.csv`` in the results directory.

    Raises:
        ValueError: If ``supervision`` or ``arch`` is not one of the
            supported values.
    """
    print("\n {}".format(AD_name))

    classes = [1, 2, 3, 4]

    # Default device to 'cpu' if cuda is not available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    results_dicts_list = []
    results_path = './results/' + dataset + '/' + AD_name + '/' + loss + '/{}/'.format(date_string)
    # The script entry point creates this directory, but a direct caller may
    # not have: make sure torch.save / savefig below have somewhere to write.
    Path(results_path).mkdir(parents=True, exist_ok=True)

    for seed in range(nbr_seeds):
        print("\n seed: {}".format(seed))

        # https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        g = torch.Generator()
        g.manual_seed(seed)

        # Class split for this seed: normal classes, the remaining classes as
        # outliers, and a subset of the outliers used as SAD classes.
        normal_cls = np.random.choice(classes, nbr_modes, replace=False).tolist()
        outlier_cls = [cls for cls in classes if cls not in normal_cls]
        SAD_cls = np.random.choice(outlier_cls, nbr_modes_SAD, replace=False).tolist()

        radar_dataset = simulated_radar_dataset(train_ratio, normal_cls, SAD_cls, SAD_ratio, batchsize,
                                                nbr_targets, nbr_pulse_per_scan)
        radar_dataset.get_dataloaders(generator=g)

        # choose which supervision setup is to be used in terms of training samples availability
        # (WARNING: check the selected loss in training.py to fully understand how supervised an experiment is)
        if supervision == "AD":
            # vanilla setup, corresponding to Deep SVDD original paper: only normal training samples available
            train_loader = radar_dataset.train_loader_sps_norm_2D
        elif supervision == "SSL":
            train_loader = radar_dataset.train_loader_sps_norm_ssl_2D
        elif supervision == "SAD":
            train_loader = radar_dataset.train_loader_sps_norm_sad_2D
        elif supervision == "SAD+SSL":
            train_loader = radar_dataset.train_loader_sps_norm_ssl_sad_2D
        else:
            raise ValueError("Supervision {} not implemented !".format(supervision))

        complete_train_loader = radar_dataset.complete_train_loader_sps_norm_2D
        val_loader = radar_dataset.val_loader_sps_norm_2D
        test_loader = radar_dataset.test_loader_sps_norm_2D

        learning_rate = lr_init

        # The network class is resolved only for the requested architecture.
        if arch == "net0":
            net = SimuSPs_LeNet0(rep_dim=64).to(device)
        elif arch == "net1":
            net = SimuSPs_LeNet1(rep_dim=64).to(device)
        elif arch == "net2":
            net = SimuSPs_LeNet2(rep_dim=64).to(device)
        elif arch == "net3":
            net = SimuSPs_LeNet3(rep_dim=64).to(device)
        elif arch == "net4":
            net = SimuSPs_LeNet4(rep_dim=64).to(device)
        elif arch == "net5":
            net = SimuSPs_LeNet5(rep_dim=64).to(device)
        elif arch == "net6":
            net = SimuSPs_LeNet6(rep_dim=64).to(device)
        elif arch == "net7":
            net = SimuSPs_LeNet7(rep_dim=64).to(device)
        else:
            raise ValueError("Architecture {} not implemented !".format(arch))

        torch.save(net.state_dict(), results_path + 'untrained_net_seed{}_normal{}.pt'.format(seed, normal_cls))

        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=False)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=lr_decay)

        RPO_ssl = Random_Projection_Outlyingness_deep_1D_ssl(nproj=nbr_RPs, unit_norm=True, device=device,
                                                             estimator=estimator)

        title = AD_name + " / " + "supervision:{}".format(supervision) + " / " + loss + " / " + "SAD_cls: {}".format(SAD_cls) + " / " + "SAD_ratio: {}".format(SAD_ratio)

        get_2D_latent_representations_trainloader(train_loader, net, device, title + " / train data before training",
                                                  figure_path=results_path,
                                                  figure_name="latent2Dtrain_untrained_seed{}_normal{}.png".format(seed, normal_cls))
        get_2D_latent_representations_testloader(test_loader, net, device, title + " / test data before training",
                                                 figure_path=results_path,
                                                 figure_name="latent2Dtest_untrained_seed{}_normal{}.png".format(seed, normal_cls))

        epoch_losses, trainAUCs, valAUCs, testAUCs, test_scores, test_labels, net = training_deepRPO(
            train_loader, complete_train_loader, val_loader,
            test_loader, normal_cls, net, device, RPO_ssl,
            optimizer, scheduler, nbr_epochs, loss_name=loss)

        torch.save(net.state_dict(), results_path + 'trained_net_seed{}_normal{}.pt'.format(seed, normal_cls))

        ###################################################################

        # Model selection: report the metrics at the index where the
        # validation AUC is highest.
        best_AUC_epoch = int(np.argmax(valAUCs))
        train_AUC = trainAUCs[best_AUC_epoch]
        valid_AUC = valAUCs[best_AUC_epoch]
        test_AUC = testAUCs[best_AUC_epoch]

        ###################################################################

        fig, ax = plt.subplots(2, 2, figsize=(15, 15))
        fig.suptitle(title)

        ax[0, 0].plot(np.arange(len(epoch_losses)), epoch_losses)
        ax[0, 0].set_xlabel("epoch")
        ax[0, 0].title.set_text('Training loss')

        auc_epochs = np.arange(nbr_epochs + 1)
        ax[0, 1].scatter(auc_epochs, trainAUCs, c='g', label="train")
        ax[0, 1].scatter(auc_epochs, valAUCs, c='b', label="valid")
        ax[0, 1].scatter(auc_epochs, testAUCs, c='r', label="test")
        ax[0, 1].set_xlabel("epoch")
        ax[0, 1].legend()
        ax[0, 1].title.set_text('AUCs during training')

        sns.violinplot(x=test_labels[0], y=test_scores[0], ax=ax[1, 0]).set(title='AD test scores before training')
        ax[1, 0].set_xlabel("class idx")

        sns.violinplot(x=test_labels[best_AUC_epoch], y=test_scores[best_AUC_epoch], ax=ax[1, 1]).set(title='AD test scores after training')
        ax[1, 1].set_xlabel("class idx")

        plt.savefig(results_path + 'seed{}_normal{}.png'.format(seed, normal_cls))
        # Release the figure: one is created per seed and pyplot keeps every
        # open figure alive until it is explicitly closed.
        plt.close(fig)

        get_2D_latent_representations_trainloader(train_loader, net, device, title + " / train data after training",
                                                  figure_path=results_path,
                                                  figure_name="latent2Dtrain_trained_seed{}_normal{}.png".format(seed, normal_cls))
        get_2D_latent_representations_testloader(test_loader, net, device, title + " / test data after training",
                                                 figure_path=results_path,
                                                 figure_name="latent2Dtest_trained_seed{}_normal{}.png".format(seed, normal_cls))

        ###################################################################

        results_dicts_list = store_comparison_results(results_dicts_list, train_AUC, valid_AUC, test_AUC, dataset,
                                                      AD_name, arch, supervision, loss, normal_cls, outlier_cls, SAD_cls,
                                                      SAD_ratio, nbr_modes, best_AUC_epoch, nbr_epochs, lr_init,
                                                      lr_milestones, batchsize, seed)

    # NOTE(review): indentation was reconstructed; the summary is assumed to be
    # built once after all seeds have run - confirm against the original file.
    results_df = pd.DataFrame(results_dicts_list)
    results_df_path = results_path + 'experiment_results.csv'
    results_df.to_csv(results_df_path)

    return results_df
def append_results_csv(results_df, csv_path):
    """Append ``results_df`` to the CSV at ``csv_path`` and return the combined frame.

    If the file does not exist yet (first run of an experiment) or is a
    zero-byte file, ``results_df`` alone is written instead of failing on
    the read.

    Args:
        results_df: DataFrame holding the rows to append.
        csv_path: Path (``str`` or ``pathlib.Path``) of the cumulative CSV.

    Returns:
        The DataFrame that was written: previous rows (if any) followed by
        the rows of ``results_df``, with a fresh integer index.
    """
    csv_path = Path(csv_path)
    frames = []
    # pd.read_csv raises on a missing file and on a zero-byte file; both
    # simply mean "no previous results".
    if csv_path.is_file() and csv_path.stat().st_size > 0:
        frames.append(pd.read_csv(csv_path, index_col=False))
    frames.append(results_df)
    combined = pd.concat(frames, ignore_index=True)
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    combined.to_csv(csv_path, index=False)
    return combined


if __name__ == '__main__':
    # All 21 settings are positional command-line arguments, in this order.
    arg_names = ("AD_name dataset nbr_modes nbr_epochs batchsize lr_init lr_decay lr_milestones "
                 "weight_decay nbr_seeds train_ratio nbr_pulse_per_scan nbr_targets estimator_RPs "
                 "nbr_RPs supervision loss arch nbr_modes_SAD SAD_ratio experiment_name")
    if len(sys.argv) < len(arg_names.split()) + 1:
        # Fail early with the expected order instead of a bare IndexError.
        sys.exit("usage: {} {}".format(sys.argv[0], arg_names))

    AD_name = str(sys.argv[1])
    dataset = str(sys.argv[2])
    nbr_modes = int(sys.argv[3])
    nbr_epochs = int(sys.argv[4])
    batchsize = int(sys.argv[5])
    lr_init = float(sys.argv[6])
    lr_decay = float(sys.argv[7])
    # Parsed as a Python literal, e.g. "[50,100]".
    lr_milestones = ast.literal_eval(sys.argv[8])
    weight_decay = float(sys.argv[9])
    nbr_seeds = int(sys.argv[10])
    train_ratio = float(sys.argv[11])
    nbr_pulse_per_scan = int(sys.argv[12])
    nbr_targets = int(sys.argv[13])
    estimator_RPs = sys.argv[14]
    nbr_RPs = int(sys.argv[15])
    supervision = str(sys.argv[16])
    loss = str(sys.argv[17])
    arch = str(sys.argv[18])
    nbr_modes_SAD = int(sys.argv[19])
    SAD_ratio = float(sys.argv[20])
    experiment_name = str(sys.argv[21])

    # NOTE(review): the timestamp contains ':' which is not a valid path
    # character on Windows; left unchanged so result paths keep their format.
    date_string = '{date:%Y-%m-%d_%H:%M:%S}'.format(date=datetime.datetime.now())

    results_directory = './results/' + dataset + '/' + AD_name + '/' + loss + '/{}/'.format(date_string)
    Path(results_directory).mkdir(parents=True, exist_ok=True)
    Path('./comparison_graphs/').mkdir(parents=True, exist_ok=True)

    results_df = deep_rpo_ssl(AD_name=AD_name, dataset=dataset, nbr_modes=nbr_modes, nbr_epochs=nbr_epochs,
                              batchsize=batchsize, lr_init=lr_init, lr_decay=lr_decay,
                              lr_milestones=lr_milestones, weight_decay=weight_decay,
                              nbr_seeds=nbr_seeds, train_ratio=train_ratio,
                              nbr_pulse_per_scan=nbr_pulse_per_scan, nbr_targets=nbr_targets,
                              estimator=estimator_RPs, nbr_RPs=nbr_RPs, date_string=date_string,
                              supervision=supervision, loss=loss, arch=arch,
                              nbr_modes_SAD=nbr_modes_SAD, SAD_ratio=SAD_ratio)

    # add results to last_results_df.csv for comparison graphs generation
    global_results_df_file_path = 'comparison_graphs/last_results_df_{}.csv'.format(experiment_name)
    global_results_df = append_results_csv(results_df, global_results_df_file_path)