-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathproxy_nca.py
139 lines (120 loc) · 4.86 KB
/
proxy_nca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from deep_net.googlenet import *
from deep_net.bn_inception import *
from tqdm import *
from pnca_loss import *
from pdb import set_trace as breakpoint
import eval_dataset
from evaluate import *
import os
import argparse
def save_dict(path, whichDataset, losses_list, train_recall_list, val_recall_list,
              train_nmi_list, val_nmi_list, best_recall):
    """Persist training/eval metric histories to ``<path>/<whichDataset>_info_dict_n_pair.log``.

    Parameters
    ----------
    path : str
        Directory for the log file; created (with parents) if it does not exist.
    whichDataset : str
        Dataset tag used in the output filename (e.g. 'cub', 'cars', 'SOP').
    losses_list, train_recall_list, val_recall_list, train_nmi_list, val_nmi_list : list
        Metric histories accumulated during training.
    best_recall : float
        Best validation recall recorded so far.
    """
    info_dict = {
        'losses': losses_list,
        'train_recall': train_recall_list,
        'val_recall': val_recall_list,
        'train_nmi': train_nmi_list,
        'val_nmi': val_nmi_list,
        'best_recall': best_recall,
    }
    # makedirs(..., exist_ok=True) avoids the check-then-create race of the
    # original exists()+mkdir pair and also creates missing parent directories.
    os.makedirs(path, exist_ok=True)
    torch.save(info_dict, os.path.join(path, whichDataset + '_info_dict_n_pair.log'))
# ---- command-line interface ----
parser = argparse.ArgumentParser(description=' Code')
parser.add_argument('--dataset', default='cars',
                    help='Training dataset, e.g. cub, cars, SOP')
args = parser.parse_args()

# ---- training hyperparameters ----
embed_size = 512        # embedding dimension produced by the backbone
num_epochs = 30
lr = 5e-4               # base learning rate (backbone + loss proxies)
fc_lr = 5e-4            # learning rate for the final fc layer
weight_decay = 1e-4
lr_decay_step = 10      # StepLR: decay every 10 epochs
lr_decay_gamma = 0.5    # StepLR: multiply lr by 0.5 at each decay step
test_interval = 5       # run evaluation every 5 epochs
n_pair_l2_reg = 0.001   # NOTE(review): not referenced in this file — confirm before removing

ALLOWED_MINING_OPS = ['npair']
REQUIRES_BATCHMINER = True
REQUIRES_OPTIM = False

# ---- dataset selection ----
# Number of training classes per supported dataset (see datasets.py for downloads).
_DATASET_CLASSES = {'cub': 100, 'cars': 98, 'SOP': 11318}
whichDataset = args.dataset  # choose from cub, cars, or SOP
if whichDataset not in _DATASET_CLASSES:
    # Exit with a non-zero status so calling shell scripts notice the failure
    # (the original print + quit() exited with status 0, masking the error).
    raise SystemExit('Specify correct dataset name')
n_classes = _DATASET_CLASSES[whichDataset]

save_model_dict_path = f'./pnca_model_dict_{whichDataset}.pt'
info_save_path = './results'
# ---- datasets and data loaders ----
# Both splits live under ./data/<DATASET>/ (upper-cased dataset name).
data_root = './data/' + whichDataset.upper() + '/'

# Training split: augmented transform, shuffled, fixed-size batches
# (drop_last keeps every batch at exactly 100 samples).
train_transform = eval_dataset.utils.make_transform()
trainset = eval_dataset.load(name=whichDataset, root=data_root,
                             mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=100,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

# Evaluation split: deterministic transform, no shuffling, keep the tail batch.
eval_transform = eval_dataset.utils.make_transform(is_train=False)
testset = eval_dataset.load(name=whichDataset, root=data_root,
                            mode='eval', transform=eval_transform)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    drop_last=False,
)
# ---- device, model, loss, and optimiser setup ----
# Prefer GPU when one is available.
dev = "cuda" if torch.cuda.is_available() else "cpu"

# Backbone network. Other architectures can be swapped in here
# (e.g. googlenet_metric(embed_size=embed_size, dropout_val=0.4)).
model = bn_inception(embedding_size=embed_size)
model.to(dev)

# Proxy-NCA criterion holds learnable proxy vectors, so it is moved to the
# device and its parameters are optimised alongside the model's.
criterion_pnca = pnca_loss(n_classes, embed_size)
criterion_pnca.to(dev)

# Three parameter groups: backbone (base lr), final fc layer (fc_lr),
# and the loss proxies (base lr).
fc_params = set(model.model.fc.parameters())
backbone_params = [p for p in model.parameters() if p not in fc_params]
model_params = [
    {'params': backbone_params},
    {'params': model.model.fc.parameters(), 'lr': float(fc_lr)},
    {'params': criterion_pnca.parameters(), 'lr': float(lr)},
]

optim = torch.optim.Adam(model_params, lr=float(lr), weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=lr_decay_step,
                                            gamma=lr_decay_gamma)
# ---- metric histories ----
losses_list = []          # training loss, sampled every 1000th batch
train_recall_list=[]
val_recall_list = []
train_nmi_list=[]
val_nmi_list = []
# NOTE(review): best_recall is initialised here but never updated below, so
# save_dict always records -1 — presumably it was meant to track the best
# validation recall; confirm and wire it up.
best_recall = -1
#breakpoint()
# ---- main training loop ----
for epoch in range(num_epochs):
    for batch_idx , (images, labels) in enumerate(tqdm(train_loader)):
        # model.eval() is set after each epoch's evaluation below, so training
        # mode is re-asserted at the start of every batch.
        model.train()
        embed_image = model(images.to(dev))
        loss = criterion_pnca(embed_image, labels.to(dev))
        #breakpoint()
        # Standard optimisation step: clear grads, backprop, update.
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Record and print the loss once every 1000 batches.
        if(batch_idx %1000 ==0):
            losses_list.append(loss.item())
            print("Loss:", loss.item())
    scheduler.step()  # per-epoch StepLR decay
    model.eval()
    # Periodic evaluation on both splits.
    # NOTE(review): these eval calls are not wrapped in torch.no_grad() — fine
    # only if get_recall_SOP / get_recall_and_NMI disable gradients internally.
    if(epoch % test_interval == 0):
        if(whichDataset == 'SOP'):
            # The NMI returned for SOP is discarded and replaced with 0 —
            # presumably NMI is skipped on SOP for cost; confirm.
            recall, nmi = get_recall_SOP(model, test_loader )
            nmi = 0
        else:
            recall, nmi = get_recall_and_NMI(model, test_loader )
        val_recall_list.append(recall)
        val_nmi_list.append(nmi)
        if(whichDataset == 'SOP'):
            train_recall, train_nmi = get_recall_SOP(model, train_loader )
            train_nmi = 0
        else:
            train_recall, train_nmi = get_recall_and_NMI(model, train_loader )
        train_recall_list.append(train_recall)
        train_nmi_list.append(train_nmi)
        # Persist metric histories and the current weights.
        # NOTE(review): the checkpoint is overwritten every test interval
        # regardless of whether recall improved.
        save_dict(info_save_path, whichDataset, losses_list, train_recall_list, val_recall_list , train_nmi_list, val_nmi_list, best_recall )
        torch.save(model.state_dict(), save_model_dict_path)
# Final checkpoint (redundant when the last epoch also hit test_interval above).
torch.save(model.state_dict(), save_model_dict_path)
print('\n\nFor the final model found:')
# Final evaluation always uses get_recall_and_NMI, even for SOP —
# NOTE(review): inconsistent with the SOP-specific branches above; confirm.
recall, nmi = get_recall_and_NMI(model, test_loader )
val_recall_list.append(recall)
val_nmi_list.append(nmi)
train_recall, train_nmi = get_recall_and_NMI(model, train_loader )
train_recall_list.append(train_recall)
train_nmi_list.append(train_nmi)
# NOTE(review): these final metrics are appended but never written out —
# call save_dict again here if they should be persisted.