-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
168 lines (139 loc) · 7.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import tqdm
import wandb
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from common.meter import Meter
from common.utils import compute_accuracy, set_seed, restart_from_checkpoint
from models.dataloader.samplers import CategoriesSampler
from models.dataloader.data_utils import dataset_builder
from models.dataloader.aux_dataloader import get_aux_dataloader
from models.renet import DCANet
from test import test_main, evaluate
from utils import rotrate_concat, record_data
from common.utils import pprint, ensure_path, set_gpu
from loss import AdaptivePrototypicalLoss # Import the new loss class
import os

# Pin training to physical GPU 2.
# NOTE(review): this hard-coded value duplicates/overrides the `-gpu` CLI
# argument defined in parse_args() below — consider unifying the two.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
# BUG FIX: CUDA_LAUNCH_BLOCKING is a boolean switch ('0'/'1'); the original
# value '2' appears copy-pasted from the GPU id. '1' forces synchronous kernel
# launches so CUDA errors are reported at the offending call (debug aid).
os.environ["CUDA_LAUNCH_BLOCKING"] = '1'
def parse_args():
    """Parse CLI arguments and attach derived run configuration.

    Side effects: prints the parsed config, configures visible GPUs,
    creates the checkpoint directory, and (unless `-wandb` is passed)
    starts a wandb run.

    Returns:
        argparse.Namespace augmented with num_gpu, device_ids, save_path,
        num_class (and crop_size=42 for CIFAR-FS).
    """
    arg_parser = argparse.ArgumentParser(description='train')

    # --- dataset ---
    arg_parser.add_argument('-dataset', type=str, default='miniImageNet',
                            choices=['miniImageNet', 'cub', 'tieredImageNet', 'CIFAR-FS', 'FC100'])
    arg_parser.add_argument('-data_root', type=str, default='/home/lxj/new_main/dataset', help='dir of datasets')
    arg_parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')

    # --- training specs ---
    arg_parser.add_argument('-batch', type=int, default=64, help='auxiliary batch size')
    arg_parser.add_argument('-temperature', type=float, default=0.2, metavar='tau', help='temperature for metric-based loss')
    arg_parser.add_argument('-lamb', type=float, default=0.25, metavar='lambda', help='loss balancing term')
    arg_parser.add_argument('--w_d', type=float, default=0.01, help='weight of distance loss')
    arg_parser.add_argument('--w_p', type=float, default=0.5)

    # --- training schedule ---
    arg_parser.add_argument('-max_epoch', type=int, default=80, help='max epoch to run')
    arg_parser.add_argument('-lr', type=float, default=0.1, help='learning rate')
    arg_parser.add_argument('-gamma', type=float, default=0.05, help='learning rate decay factor')
    arg_parser.add_argument('-milestones', nargs='+', type=int, default=[60, 70], help='milestones for MultiStepLR')
    arg_parser.add_argument('--save_freq', type=int, default=5, help='save frequency')
    arg_parser.add_argument('-save_all', action='store_true', help='save models on each epoch')
    arg_parser.add_argument('-use_resume', action='store_true', help='use the result of training before')
    arg_parser.add_argument('--resume_file', type=str, default='epoch_10.pth')

    # --- few-shot episode layout ---
    arg_parser.add_argument('-way', type=int, default=5, metavar='N', help='number of few-shot classes')
    arg_parser.add_argument('-shot', type=int, default=1, metavar='K', help='number of shots')
    arg_parser.add_argument('-query', type=int, default=15, help='number of query image per class')
    arg_parser.add_argument('-val_episode', type=int, default=200, help='number of validation episode')
    arg_parser.add_argument('-test_episode', type=int, default=2000, help='number of testing episodes after training')
    arg_parser.add_argument('-proto_size', type=int, default=100, help='the number of dynamic prototypes')
    arg_parser.add_argument('--crop_size', type=int, default=84)
    arg_parser.add_argument('--trans', type=int, default=4, help='number of transformations')
    arg_parser.add_argument('--hidden_size', type=int, default=320, help='hidden size for cross attention layer')
    arg_parser.add_argument('--feat_dim', type=int, default=640)
    arg_parser.add_argument('--sup_t', type=float, default=0.2)

    # --- CoDA ---
    arg_parser.add_argument('-temperature_attn', type=float, default=2, metavar='gamma', help='temperature for softmax in computing cross-attention')

    # --- environment / logging ---
    arg_parser.add_argument('-gpu', default='2', help='the GPU ids e.g. \"0\", \"0,1\", \"0,1,2\", etc')
    arg_parser.add_argument('-test_tag', type=str, default='test_wp0.5_wd0.01_lam0.25_t2', help='extra dir name added to checkpoint dir')
    arg_parser.add_argument('-seed', type=int, default=1, help='random seed')
    # Flag DISABLES wandb logging: omit it for training runs, pass it for tests.
    arg_parser.add_argument('-wandb', action='store_true', help='not plotting learning curve on wandb',
                            )  # train: enable logging / test: disable logging

    args = arg_parser.parse_args()
    pprint(vars(args))
    torch.set_printoptions(linewidth=100)

    # Derived runtime configuration.
    args.num_gpu = set_gpu(args)
    args.device_ids = None if args.gpu == '-1' else list(range(args.num_gpu))
    args.save_path = os.path.join(f'checkpoints/{args.dataset}/{args.shot}shot-{args.way}way/', args.test_tag)
    ensure_path(args.save_path)

    # wandb is started BEFORE num_class is attached (matches original order).
    if not args.wandb:
        wandb.init(project=f'renet-{args.dataset}-{args.way}w{args.shot}s',
                   config=args,
                   save_code=True,
                   name=args.test_tag)

    # Per-dataset base-class counts ('cars'/'dogs' kept even though the
    # `choices` list above currently excludes them).
    class_counts = {
        'miniImageNet': 64,
        'cub': 100,
        'FC100': 60,
        'tieredImageNet': 351,
        'CIFAR-FS': 64,
        'cars': 130,
        'dogs': 70,
    }
    if args.dataset in class_counts:
        args.num_class = class_counts[args.dataset]
    if args.dataset == 'CIFAR-FS':
        args.crop_size = 42  # CIFAR images are small; shrink the crop

    return args
def train(epoch, model, loader, optimizer, criterion, args=None):
    """Run one training epoch over episodic and auxiliary loaders.

    Args:
        epoch: current epoch index (used only in the progress-bar label).
        model: DataParallel-wrapped network; ``model.module.mode`` switches
            between the 'encoder', 'coda' and 'fc' forward branches.
        loader: dict with 'train_loader' (episodic sampler) and
            'train_loader_aux' (plain batches for the auxiliary losses).
        optimizer: optimizer stepped once per episode.
        criterion: loss applied to the absolute (global) episode logits.
        args: parsed namespace from parse_args().

    Returns:
        Tuple of (average loss, average accuracy, accuracy confidence
        interval) accumulated by the meters over the epoch.
    """
    model.train()
    train_loader = loader['train_loader']
    train_loader_aux = loader['train_loader_aux']

    # Query labels always follow the fixed pattern 0,1,...,way-1 repeated:
    # 012340123401234...
    query_label = torch.arange(args.way).repeat(args.query).cuda()

    loss_meter = Meter()
    acc_meter = Meter()
    k = args.way * args.shot  # number of support images per episode
    tqdm_gen = tqdm.tqdm(train_loader)

    # Episode-level adaptive prototypical loss, built once per epoch.
    adaptive_prototypical_loss = AdaptivePrototypicalLoss(args)

    for (data, train_labels), (data_aux, train_labels_aux) in zip(tqdm_gen, train_loader_aux):
        data, train_labels = data.cuda(), train_labels.cuda()
        data_aux = data_aux.cuda()
        batch_size = data_aux.size(0)
        # Build the rotated/transformed auxiliary batch; labels replicate once
        # per transformation.
        data_aux = rotrate_concat([data_aux])
        train_labels_aux = train_labels_aux.repeat(args.trans).cuda()

        # Forward images (3, 84, 84) -> (C, H, W). Episode data and aux data
        # are forwarded separately because of BatchNorm statistics.
        model.module.mode = 'encoder'
        data, fea_loss, cst_loss, dis_loss = model(data)
        data_aux = model(data_aux, aux=True)

        # Cross-attention few-shot head: split support/query, compute logits.
        model.module.mode = 'coda'
        data_shot, data_query = data[:k], data[k:]
        logits, absolute_logits = model((data_shot.unsqueeze(0).repeat(1, 1, 1, 1, 1), data_query))

        # BUG FIX: the original used `absolute_loss` below without ever
        # defining it (NameError on the first iteration). The `criterion`
        # parameter was passed in but never used; applying it to the absolute
        # logits is the evident intent. TODO(review): confirm with the caller.
        absolute_loss = criterion(absolute_logits, query_label)

        # Adaptive prototypical loss on episode features.
        # NOTE(review): assumes the wrapped model exposes get_features() —
        # confirm it exists on the DataParallel wrapper.
        features = model.get_features(data_shot, data_query)
        loss = adaptive_prototypical_loss(features, query_label)

        # Auxiliary global-classification loss on the transformed batch.
        model.module.mode = 'fc'
        logits_global, logits_eq = model(data_aux)
        loss_aux = F.cross_entropy(logits_global, train_labels_aux)

        # Equivariance loss: proxy label t marks the t-th transformed chunk.
        proxy_labels = torch.zeros(args.trans * batch_size).cuda().long()
        for t in range(args.trans):
            proxy_labels[t * batch_size:(t + 1) * batch_size] = t
        loss_eq = F.cross_entropy(logits_eq, proxy_labels)

        # Encoder regularizers (cst_loss is returned but currently unused).
        l_re = fea_loss + dis_loss * args.w_d
        loss_aux = absolute_loss + loss_aux
        total_loss = args.lamb * loss + loss_aux + loss_eq + l_re

        acc = compute_accuracy(logits, query_label)
        loss_meter.update(total_loss.item())
        acc_meter.update(acc)
        tqdm_gen.set_description(f'[train] epo:{epoch:>3} | avg.loss:{loss_meter.avg():.4f} | avg.acc:{acc_meter.avg():.3f} (curr:{acc:.3f})')

        optimizer.zero_grad()
        total_loss.backward()
        # Clip gradients to keep the combined multi-term loss stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()

    return loss_meter.avg(), acc_meter.avg(), acc_meter.confidence